import numpy as np
import torch
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from functorch.experimental import replace_all_batch_norm_modules_
import clip
from pydvl.influence.torch import (
    DirectInfluence,
    CgInfluence,
    ArnoldiInfluence,
    EkfacInfluence,
    NystroemSketchInfluence,
    LissaInfluence,
)
from pydvl.influence.torch.util import (
    NestedTorchCatAggregator, 
    TorchNumpyConverter,
)
from pydvl.influence import InfluenceMode, SequentialInfluenceCalculator
from model_train import train_with_fea_val, MLP, get_val_loss


# Module-wide compute device string: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


def _fit_eval_mlp(cla_model, fea, lab, val_x, val_y, train_lr):
    """Train a fresh MLP shaped like ``cla_model`` on (fea, lab) and return ``train_with_fea_val``'s tuple."""
    mlp = MLP(in_num=cla_model.in_fea_dim, hidden_num=128, out_num=cla_model.out_fea_dim).to(device)
    return train_with_fea_val(mlp, fea, lab, val_x, val_y, train_lr, max_train_epochs=300)


def _augment_with_embedding(train_fea, train_lab, emb, label):
    """Return (features, labels) with several copies of ``emb`` labelled ``label`` appended."""
    # About 1% of the training set, capped at 10 copies; always at least one copy so the
    # candidate embedding actually takes part in retraining (0 copies would compare
    # two models trained on identical data).
    repeat_num = max(1, min(train_fea.shape[0] // 100, 10))
    extra_fea = emb.to(device=train_fea.device, dtype=train_fea.dtype).unsqueeze(0).repeat(repeat_num, 1)
    # Build a 1-D label vector directly; `.repeat(1, n).squeeze()` collapses to a 0-d
    # tensor when n == 1, which torch.cat refuses to concatenate.
    extra_lab = torch.as_tensor([label], dtype=train_lab.dtype, device=train_lab.device).repeat(repeat_num)
    return torch.cat([train_fea, extra_fea], dim=0), torch.cat([train_lab, extra_lab])


def optimize_img_embedding_by_influence(cla_model, clip_model, text_generator, loss, train_fea, train_lab,
                                        val_x, val_y, desired_class, desired_class_name, negative_class_name,
                                        img_origin, step_size=1e-6, optimize_iterations=1000, train_lr=1e-2,
                                        eval_every=10, lazy_threshold=3000):
    """
    Optimize one image embedding by stepping along its influence on the validation set.

    An ``ArnoldiInfluence`` model is fitted on ``(train_fea, train_lab)``. Starting from
    ``img_origin``, each iteration computes the perturbation influence of the current
    embedding (labelled ``desired_class``) on the validation points. It then moves the
    embedding along the mean influence via ``update_img_embedding``, which clips to
    ``[-1, 1]`` and L2-normalises. Every ``eval_every`` iterations, a fresh MLP is trained
    on the training set augmented with copies of the embedding, and its validation loss
    is compared with a baseline MLP trained on the original training set.

    Args:
        cla_model: Trained classifier exposing ``in_fea_dim`` / ``out_fea_dim``; its
            batch-norm modules are patched in place by ``replace_all_batch_norm_modules_``.
        clip_model: Currently unused; kept for interface compatibility.
        text_generator: Object with ``get_text_description(img_clip_fea=...)``, used to
            caption embeddings that set a new best loss improvement.
        loss: Loss module used by the influence model.
        train_fea, train_lab: Training features and 1-D labels.
        val_x, val_y: Validation features and labels.
        desired_class: Label assigned to the optimized embedding.
        desired_class_name: Class name inserted into the generated caption.
        negative_class_name: Currently unused; kept for interface compatibility.
        img_origin: Starting embedding (cloned, never modified).
        step_size: Step size of each update (default: 1e-6).
        optimize_iterations: Number of update steps (default: 1000).
        train_lr: Learning rate for the evaluation MLPs (default: 1e-2).
        eval_every: Positive interval (in iterations) between evaluations (default: 10).
        lazy_threshold: If the validation set has more rows than this, influences are
            computed batch-wise with ``SequentialInfluenceCalculator`` (default: 3000).

    Returns:
        dict mapping caption -> (embedding on CPU, loss improvement, iteration). The dict
        has one entry for every evaluation (after iteration 0) that beat the previous best
        validation-loss improvement.
    """
    lb, ub = get_clip_range(None)
    lazy_calc = val_x.shape[0] > lazy_threshold

    # Fit the influence model on the module-wide device instead of a hard-coded GPU,
    # so the function also runs on CPU-only machines.
    cla_model = replace_all_batch_norm_modules_(cla_model)
    # Other pydvl estimators imported at file level (DirectInfluence, CgInfluence,
    # LissaInfluence, NystroemSketchInfluence, EkfacInfluence) can be swapped in here.
    infl_model = ArnoldiInfluence(cla_model.to(device), loss.to(device))
    train_tensor_loader = DataLoader(TensorDataset(train_fea, train_lab), batch_size=32)
    infl_model = infl_model.fit(train_tensor_loader)
    desired_label = torch.as_tensor([desired_class])
    if lazy_calc:
        infl_calc = SequentialInfluenceCalculator(infl_model)
        # Loop-invariant: the validation loader does not depend on the embedding.
        val_data_loader = DataLoader(TensorDataset(val_x, val_y), batch_size=128)

    # Baseline validation loss/accuracy of an MLP trained on the unmodified training set.
    clone_cla_model = _fit_eval_mlp(cla_model, train_fea, train_lab, val_x, val_y, train_lr)[0]
    ini_val_loss, ini_val_acc = get_val_loss(clone_cla_model.to(device), val_x.to(device), val_y.to(device))

    img_optimized = img_origin.clone()
    best_improving = 0
    best_iter = 0
    ret_dict = dict()
    for it in range(optimize_iterations):
        if lazy_calc:
            img_optimized_loader = DataLoader(TensorDataset(img_optimized.unsqueeze(0), desired_label), batch_size=1)
            influences = infl_calc.influences(val_data_loader, img_optimized_loader, mode=InfluenceMode.Perturbation)
            influences = influences.compute(aggregator=NestedTorchCatAggregator())
        else:
            influences = infl_model.influences(val_x, val_y, img_optimized.unsqueeze(0), desired_label,
                                               mode=InfluenceMode.Perturbation)
        # Step along the influence averaged over the validation points.
        img_optimized = update_img_embedding(img_optimized.to(device), step_size,
                                             torch.mean(influences, dim=0).squeeze(), lb, ub)
        if it % eval_every != 0:
            continue

        aug_fea, aug_lab = _augment_with_embedding(train_fea, train_lab, img_optimized, desired_class)
        ret_tuple = _fit_eval_mlp(cla_model, aug_fea, aug_lab, val_x, val_y, train_lr)
        # NOTE(review): assumes train_with_fea_val returns (model, val_loss, val_acc, ...) -- confirm.
        loss_improved = ini_val_loss - ret_tuple[1]  # lower loss is better
        acc_improved = ret_tuple[2] - ini_val_acc    # higher accuracy is better
        # best_improving never drops below 0, so a new best already implies a positive
        # loss improvement. Captions are generated only for embeddings that get stored.
        if it > 0 and loss_improved > best_improving:
            gen_txt = text_generator.get_text_description(img_clip_fea=img_optimized.to(device))
            verb = 'is' if 'image' in gen_txt or 'photo' in gen_txt else 'illustrates'
            gen_txt = f"This picture associated with {desired_class_name}. Specifically, it {verb} " + gen_txt
            best_improving = loss_improved
            best_iter = it
            ret_dict[gen_txt] = (img_optimized.cpu().clone(), best_improving, it)
        print(f"iter: {it}; loss improved: {loss_improved}; acc improved: {acc_improved}")
    print(f"best iter: {best_iter}; best loss improved: {best_improving}")
    return ret_dict


def get_sim_between_img_txt(clip_model, image_features, text_descriptions):
    """
    Return dot-product similarities between image features and CLIP text embeddings.

    The text embeddings are L2-normalised; ``image_features`` are used as given.
    Side effect: switches ``clip_model`` to eval mode and does not restore it.

    Args:
        clip_model: CLIP model providing ``encode_text``.
        image_features: Image embeddings, anything ``torch.as_tensor`` accepts; presumably
            shape (n_images, dim) matching the text embedding width -- TODO confirm.
        text_descriptions: String or list of strings passed to ``clip.tokenize``.

    Returns:
        Tensor ``image_features @ text_features.T`` on ``device``.
    """
    text_inputs = clip.tokenize(text_descriptions).to(device)

    clip_model.eval()
    with torch.no_grad():
        text_features = clip_model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # The text features' dtype follows clip_model (e.g. fp16 for a half-precision model).
    # Cast the image side to match so the matmul does not fail on a dtype mismatch.
    image_features = torch.as_tensor(image_features).to(device=device, dtype=text_features.dtype)
    return image_features @ text_features.T


def get_clip_range(X_orig, box_radius=0.5 * 2.0 / 255.0):
    """
    Return element-wise (lower, upper) bounds for an embedding, restricted to [-1, 1].

    Args:
        X_orig: ``None`` for the global range, otherwise a torch tensor (any device,
            may require grad) or a numpy array to build a box around.
        box_radius: Half-width of the box around ``X_orig`` (default: 1/255).

    Returns:
        ``(-1, 1)`` when ``X_orig`` is None. Otherwise
        ``(max(X_orig - box_radius, -1), min(X_orig + box_radius, 1))``, computed with
        numpy for ndarray input and as torch tensors otherwise.
    """
    if X_orig is None:
        return -1, 1
    if isinstance(X_orig, np.ndarray):
        return np.maximum(X_orig - box_radius, -1.0), np.minimum(X_orig + box_radius, 1.0)
    # torch.clamp works on any device and under autograd, unlike numpy ufuncs, which
    # need a CPU tensor that does not require grad.
    X_orig = torch.as_tensor(X_orig)
    return torch.clamp(X_orig - box_radius, min=-1.0), torch.clamp(X_orig + box_radius, max=1.0)


def update_img_embedding(emb, step_size, grad_influence_wrt_input_val_subset, lb=-1, ub=1, eps=1e-12):
    """
    Take one step on an embedding, clip it to [lb, ub] and L2-normalise it.

    Args:
        emb: Current embedding tensor (not modified).
        step_size: Scalar multiplier for the step.
        grad_influence_wrt_input_val_subset: Step direction, broadcastable to ``emb``.
        lb, ub: Bounds passed to ``torch.clip`` (numbers or tensors).
        eps: Lower bound on the norm used for normalisation. It only takes effect when
            the clipped embedding is (near) zero: such an input yields zeros instead of
            NaNs, which would otherwise propagate through an optimisation loop.

    Returns:
        New tensor ``clip(emb + step_size * grad, lb, ub) / max(||.||_2, eps)``. The
        norm is taken over all elements of the tensor, not per row.
    """
    stepped = torch.clip(emb + step_size * grad_influence_wrt_input_val_subset, lb, ub)
    return stepped / stepped.norm(p=2).clamp_min(eps)
