import torch
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from .utils import group_images, flatten_images
# from datasets import Cifar10
from transformers import Idefics3Processor, Idefics3ForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
# from all_rules import rules
from .utils import load_image, count_rules
import numpy as np
from fuzzywuzzy import fuzz

import warnings
# NOTE(review): this call (no category/module given) ignores *every* warning
# for the whole interpreter, as a side effect of importing this module --
# including warnings from torch/transformers and from the code below.
# Consider scoping it with ``warnings.catch_warnings()`` or filtering only
# specific categories instead.
warnings.filterwarnings("ignore")


class RiskAnalysis:
    """Base interface for risk-analysis pipelines.

    Every hook accepts arbitrary keyword arguments, does nothing and returns
    ``None``; subclasses override the hooks they need.

    Grad mode is the implementation's responsibility. A class-body
    ``with torch.inference_mode():`` wrapped around a ``def`` only affects
    the execution of the ``def`` statement itself, not later calls of the
    method. Implementations should therefore enter ``torch.inference_mode()``
    inside the method body, or decorate the method with
    ``@torch.inference_mode()``.
    """

    def get_activate_vector(self, **kwargs):
        """Hook for computing rule-activation vectors. No-op in the base class."""
        return None

    def get_mu(self, **kwargs):
        """Hook for computing ``mu`` statistics. No-op in the base class."""
        return None

    def train_lrm(self, **kwargs):
        """Training hook. No-op in the base class."""
        return None

class SPRA(RiskAnalysis):
    """Rule-based risk analysis over model responses.

    Holds a VLM, an LLM and an optional auxiliary VLM (plus their
    processors), which are forwarded to a user-supplied ``rules`` callable.
    The class exposes three stages:

    * :meth:`get_activate_vector` runs ``rules`` on every sample and saves
      the concatenated per-sample activations.
    * :meth:`get_mu` computes, for every rule, the fraction of correct
      responses (``response == label``) among the samples where that rule
      fired.
    * :meth:`evalu_generate_risk_label` marks each sample as risky
      (``response != label``) and optionally saves the labels.

    Grad mode is handled inside the methods: ``rules`` is always invoked
    under ``torch.inference_mode()``. A class-level ``with`` around a
    ``def`` would only affect the definition, not later calls, so no such
    wrapper is used.
    """

    def __init__(self, vlm=None, vlm_processor=None, llm=None, llm_processor=None, lrm=None, other_vlm=None, other_processor=None):
        """Store the models and processors used by the analysis.

        :param vlm: vision-language model forwarded to ``rules``; its
            ``.device`` is also used for placeholder rows in
            :meth:`get_activate_vector`.
        :param vlm_processor: processor for ``vlm``, forwarded to ``rules``.
        :param llm: language model forwarded to ``rules``.
        :param llm_processor: processor for ``llm``, forwarded to ``rules``.
        :param lrm: stored as ``self.lrm``; not used by the methods of this
            class.
        :param other_vlm: auxiliary VLM forwarded to ``rules``.
        :param other_processor: processor for ``other_vlm``, forwarded to
            ``rules``.
        """
        self.vlm_processor = vlm_processor
        self.vlm = vlm
        self.llm_processor = llm_processor
        self.llm = llm
        self.lrm = lrm
        self.other_vlm = other_vlm
        self.other_processor = other_processor

    def _apply_rules(self, d, rules, original_data, cls_shot, dataset):
        """Run ``rules`` on the single sample ``d`` under inference mode and return its output."""
        with torch.inference_mode():
            return rules(
                image=[d['images']],  # must be a list like [image1, image2]
                lang_x=d['prompt'],
                label=[d['label']],
                vlm_processor=self.vlm_processor,
                llm_processor=self.llm_processor,
                vlm=self.vlm,
                llm=self.llm,
                original_data=original_data,
                cls_shot=cls_shot,
                predcted_cls=d['response'],  # keyword spelling is part of the rules() interface
                dataset=dataset,
                other_processor=self.other_processor,
                other_vlm=self.other_vlm,
            )

    def get_activate_vector(self, original_data: list, activate_path='', rules=None, cls_shot=None, dataset=None, reference=None, n_rules=29):
        """Compute, save and return the rule-activation matrix for ``original_data``.

        :param original_data: list of sample dicts (loaded from json). Each
            dict must provide ``'images'``, ``'prompt'``, ``'label'`` and
            ``'response'``.
        :param activate_path: file path handed to ``torch.save``.
        :param rules: callable invoked once per sample with keyword
            arguments. Its returned tensors are concatenated along dim 0,
            presumably one ``(1, n_rules)`` row per sample (TODO confirm).
        :param cls_shot: forwarded to ``rules``.
        :param dataset: forwarded to ``rules``.
        :param reference: optional sequence of sample dicts aligned with
            ``original_data``. When given, a sample whose ``'response'``
            equals the reference's ``'response'`` is not passed to
            ``rules``; it gets a row of ``-1`` instead. Iteration then stops
            at the shorter of the two sequences.
        :param n_rules: width of the ``-1`` placeholder rows.
        :return: the activation matrix on CPU (the same tensor that is saved).
        :raises RuntimeError: from ``torch.cat`` if no rows were produced
            (e.g. empty ``original_data``).
        """
        activate_vector = []

        with tqdm(original_data, desc='Computing:') as dbar:
            if not reference:
                for d in dbar:
                    activate_vector.append(self._apply_rules(d, rules, original_data, cls_shot, dataset))
            else:
                for d, rd in zip(dbar, reference):
                    if d['response'] == rd['response']:
                        # Same answer as the reference: skip the (expensive)
                        # rule evaluation and mark every rule with -1.
                        activate_vector.append(torch.full((1, n_rules), -1.0, device=self.vlm.device))
                        continue
                    activate_vector.append(self._apply_rules(d, rules, original_data, cls_shot, dataset))

        activate_vector = torch.cat(activate_vector, dim=0).cpu()
        torch.save(activate_vector, activate_path)
        return activate_vector

    def get_mu(self, original_data, batch_size, activate_vectors, mu_path='', id=''):
        """Compute, save and return the per-rule accuracy ``mu``.

        For rule ``i``, ``mu[i] = hits_i / (fired_i + 1e-6)``, where:

        * ``fired_i`` is the number of samples whose activation for rule
          ``i`` (cast to int) equals 1. Zeros and the ``-1`` placeholders
          from :meth:`get_activate_vector` do not count as fired.
        * ``hits_i`` is how many of those fired samples have
          ``response == label``.

        :param original_data: list of sample dicts with ``'response'`` and
            ``'label'``, row-aligned with ``activate_vectors``.
        :param batch_size: unused; kept for backward compatibility (the
            computation is a single vectorized pass).
        :param activate_vectors: 2-D ``(n_samples, n_rules)`` tensor or array.
        :param mu_path: directory to save ``{id}_mu.pth`` into. It is
            created if needed; ``''`` means the current directory.
        :param id: filename prefix.
        :return: float64 tensor of shape ``(n_rules,)``.
        :raises ValueError: if ``activate_vectors`` is not 2-D or its row
            count differs from ``len(original_data)``.
        """
        acts = torch.as_tensor(activate_vectors).detach().cpu().numpy()
        if acts.ndim != 2 or acts.shape[0] != len(original_data):
            raise ValueError(
                f'activate_vectors must have shape (len(original_data), n_rules); '
                f'got {tuple(acts.shape)} for {len(original_data)} samples'
            )

        correct = np.array([b['response'] == b['label'] for b in original_data], dtype=bool)

        # Boolean mask, not integer indices: indexing with ``astype(int)``
        # would select rows 0/1/-1 instead of filtering the fired samples.
        fired = acts.astype(int) == 1                      # (n_samples, n_rules)
        hits = (fired & correct[:, None]).sum(axis=0)       # correct AND fired
        trials = fired.sum(axis=0)                          # fired
        mu = torch.from_numpy(hits / (trials + 1e-6))

        if mu_path:  # os.makedirs('') raises, so only create a real directory
            os.makedirs(mu_path, exist_ok=True)
        torch.save(mu, os.path.join(mu_path, f'{id}_mu.pth'))
        return mu

    def evalu_generate_risk_label(self, original_data, batch_size, risk_label_path='', save=True, id=''):
        """Label each sample as risky (``response != label``), optionally save, and return the labels.

        A progress bar shows the running accuracy as ``mu_1``.

        :param original_data: sliceable list of sample dicts with
            ``'response'`` and ``'label'``.
        :param batch_size: number of samples processed per progress-bar step.
        :param risk_label_path: directory to save ``{id}_risk_label.pth``
            into when ``save`` is true. It is created if needed; ``''``
            means the current directory.
        :param save: whether to write the labels to disk.
        :param id: filename prefix.
        :return: 1-D bool tensor, ``True`` where the response is wrong.
        """
        alpha = 0        # correct responses seen so far
        alpha_beta = 0   # samples seen so far
        risk_label = []

        with tqdm(range(0, len(original_data), batch_size), desc="Computing") as progress_bar:
            for j in progress_bar:
                batch = original_data[j:j + batch_size]
                wrong = np.array([b['response'] != b['label'] for b in batch], dtype=bool)

                # Count the actual batch length: the last batch may be short.
                alpha_beta += len(batch)
                alpha += int((~wrong).sum())
                risk_label.append(torch.from_numpy(wrong))

                progress_bar.set_postfix(mu_1=f'{alpha / (alpha_beta + 1e-6):.2f}')

        risk_label = torch.cat(risk_label, dim=0) if risk_label else torch.zeros(0, dtype=torch.bool)

        if save:
            if risk_label_path:  # os.makedirs('') raises
                os.makedirs(risk_label_path, exist_ok=True)
            torch.save(risk_label, os.path.join(risk_label_path, f'{id}_risk_label.pth'))

        return risk_label
