import numpy as np
from termcolor import colored
import open_clip
import torch
from sklearn.linear_model import Ridge
from sklearn.kernel_ridge import KernelRidge

from aux_func import return_pairwise_kernel, clip_ucb_bonus, RFFKernel


# Sets of T2I models for evaluation (one bandit "arm" per entry)
repo_ids = np.array(["runwayml/stable-diffusion-v1-5",
                     "stabilityai/stable-diffusion-2-base",
                     "PixArt-alpha/PixArt-XL-2-512x512",
                     "DeepFloyd"])
G = 4                                   # number of candidate models; matches len(repo_ids)

# Load the CLIP model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
N_feat = 512                            # feature dimension of the regressors (CLIP ViT-B-32 embedding size by default)
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-B-32')
# Load the checkpoint onto CPU first so a GPU-saved state dict also works on a
# CPU-only host; the model is moved to `device` right after.
# TODO: '' is a placeholder for the checkpoint path.
clip_model.load_state_dict(torch.load('', map_location='cpu'))
clip_model.eval()
clip_model.to(device)
tform = NotImplementedError     # image transformation for CLIP (placeholder, not yet provided)
tokenizer = open_clip.get_tokenizer('ViT-B-32')
print(colored('##  Pre-trained CLIP ViT-B-32 is loaded.', 'red'))


def _clip_text_features(prompt):
    """Return the L2-normalised CLIP text embedding of ``prompt`` as a numpy array.

    Inference runs under ``torch.no_grad()`` so no autograd graph is built.
    The leading batch axis returned by ``encode_text`` is kept.
    """
    with torch.no_grad():
        feat = clip_model.encode_text(tokenizer(prompt).to(device))
        feat = feat / feat.norm(dim=-1, keepdim=True)
    return feat.cpu().numpy()


def _clip_image_features(image):
    """Return the L2-normalised CLIP image embedding of ``image`` as a numpy array.

    ``.cpu()`` is required before ``.numpy()`` when the model runs on CUDA.
    NOTE(review): ``image`` is passed to ``encode_image`` unchanged; it presumably
    must already be preprocessed (e.g. with ``tform``), batched and on ``device`` —
    confirm once image generation is implemented.
    """
    with torch.no_grad():
        feat = clip_model.encode_image(image)
        feat = feat / feat.norm(dim=-1, keepdim=True)
    return feat.cpu().numpy()


def _predict_scores(regs, feat):
    """Return one prediction per fitted regressor in ``regs`` for the single query row ``feat``."""
    return np.array([reg.predict(feat)[0] for reg in regs])


if __name__ == '__main__':

    np.random.seed(1234)

    alg_type = 'SCK_UCB'                                                        # greedy, naive_KRR, SCK_UCB

    apply_rff = False
    if alg_type in ['naive_KRR', 'SCK_UCB']:
        gamma = 3.                                                              # parameter for kernels
        alpha = 1.                                                              # ridge regularization, used by both kernel algorithms
        kernel_method = 'rbf'                                                   # kernels: linear/polynomial/rbf
        kernel = return_pairwise_kernel(kernel_method=kernel_method)
        if kernel_method == 'rbf':
            apply_rff = True
            N_feat = 50                                                         # number of random Fourier features
            # input_dim is the CLIP ViT-B-32 embedding size
            rff = RFFKernel(input_dim=512, num_features=N_feat, sigma=gamma)    # random features generator
    if alg_type == 'SCK_UCB':
        delta = .05                                                             # failure probability
        eta = np.sqrt(2. * np.log(2. * G / delta))                              # exploration parameter

    for epoch in range(1, 21):                                                  # averaged over 20 trials

        # Initialization
        alg_score = 0.
        regs = [[] for _ in range(G)]                                           # set of regression models
        drep = [np.empty((1, N_feat,)) for _ in range(G)]                       # features of prompts seen per model
        dscore = [np.empty((1,)) for _ in range(G)]                             # historical (normalized) scores
        model_pred_score = np.ones(G) * np.inf                                  # inf marks an unvisited model
        model_emp_score = np.ones(G) * np.inf                                   # empirical scores
        visit = np.zeros((G,))                                                  # visitation

        # Start of evaluation
        for rd in range(1, 5001):

            prompt2gen = NotImplementedError                                    # prompt to generate
            prompt2gen_feat = _clip_text_features(prompt2gen)
            # Features the regressors and the UCB bonus operate on: the RFF map
            # of the CLIP embedding when enabled, otherwise the embedding itself.
            query_feat = rff.transform(prompt2gen_feat) if apply_rff else prompt2gen_feat

            # Pick model
            if alg_type == 'greedy':
                select_model = np.argmax(model_emp_score / np.maximum(np.ones(G), visit))

            elif alg_type in ('SCK_UCB', 'naive_KRR'):
                if rd <= G:
                    # Unvisited models keep an inf predicted score, so each is tried once first.
                    select_model = np.argmax(model_pred_score)
                else:
                    est_score = _predict_scores(regs, query_feat)
                    if alg_type == 'SCK_UCB':
                        bonus = clip_ucb_bonus(prompt2gen_feat=query_feat, drep=drep,
                                               kernel_method=kernel_method, kernel=kernel, alpha=alpha,
                                               gamma=gamma, N_feat=N_feat, apply_rff=apply_rff)
                        select_model = np.argmax(100 * (est_score + eta * bonus))
                    else:                                                       # naive_KRR: no exploration
                        select_model = np.argmax(100. * est_score)

            else:
                raise NotImplementedError

            gen_img = NotImplementedError                                       # Generate an image
            gen_img_feat = _clip_image_features(gen_img)
            cosine_angle = (gen_img_feat.squeeze(0)).dot(prompt2gen_feat.squeeze(0))
            score = max(0., 100. * cosine_angle)        # Compute CLIPScore for the generated images

            # Update statistics
            if visit[select_model] == 0:
                model_emp_score[select_model] = score
                model_pred_score[select_model] = 0.
            else:
                model_emp_score[select_model] += score
            alg_score += score

            # Updates estimators
            if alg_type in ['SCK_UCB', 'naive_KRR']:
                if visit[select_model] == 0:
                    # Overwrite the uninitialised placeholder row on the first visit.
                    drep[select_model][0] = query_feat
                    dscore[select_model][0] = score / 100
                else:
                    drep[select_model] = np.concatenate((drep[select_model], query_feat), axis=0)
                    dscore[select_model] = np.concatenate((dscore[select_model], np.array([score / 100])), axis=0)

                visit[select_model] += 1

                # Fit the regression model
                if apply_rff:
                    regs[select_model] = Ridge(alpha=alpha).fit(drep[select_model], dscore[select_model])
                else:
                    regs[select_model] = KernelRidge(kernel=kernel_method, alpha=alpha,
                                                     gamma=gamma).fit(drep[select_model], dscore[select_model])
            elif alg_type == 'greedy':
                visit[select_model] += 1
            else:
                raise NotImplementedError

            if rd % 100 == 0:
                print(colored(f'epoch: {epoch}, step: {rd}, alg.avg.score: {alg_score / rd}',
                              'red'), '\n')
