import torch
from hashlib import sha1
import numpy as np
import pyhash
from imagehash import phash
from PIL import Image
from typing import List, Tuple, Union
import torch.nn as nn
import random
from ml_common import get_device
from torch.nn import Module
from collections import Counter
import torch.nn.functional as F

# Shared cosine-similarity op that compares two tensors along their last dimension.
cosine_similarity = nn.CosineSimilarity(dim=-1)


class D3Model:
    """Ensemble that routes each input sample to one member model via a hash.

    ``__call__`` evaluates every member of ``model_list`` on the clamped
    (and optionally quantized) input. For each sample it returns the output
    of the member whose index ``get_hash_list`` produced. The chosen indices
    are recorded in ``self.hash_list``, and the number of samples seen is
    recorded in ``self.n_queries``.
    """

    def __init__(
        self,
        model_list: List[nn.Module],
        bounds: Tuple[float, float],
        num_classes: int = 10,
        quantization: bool = False,
        hash_size: int = 4,
        model_hash: Union[nn.Module, None] = None,
        hash_mode: str = "phash",
    ):
        """
        Args:
            model_list: member models. Each is moved to the device returned
                by ``get_device()`` and put in eval mode.
            bounds: (low, high) range that inputs are clamped to.
            num_classes: stored on the instance. No method of this class
                reads it.
            quantization: if True, ``preprocess`` rounds inputs to multiples
                of 1/128.
            hash_size: hash size passed to ``phash`` in "phash" mode.
            model_hash: network used for routing in "dnn" mode.
            hash_mode: one of "phash", "dnn", "sha1" (see ``get_hash_list``).

        Raises:
            ValueError: if ``hash_mode`` is "dnn" and no ``model_hash`` is given.
        """
        # Validate first, so misconfiguration fails before any device work.
        # A raise, unlike an assert, is not stripped under `python -O`.
        if hash_mode == "dnn" and model_hash is None:
            raise ValueError('hash_mode="dnn" requires a model_hash network')

        self.device = get_device()
        self.model_list = model_list
        self.n_models = len(self.model_list)
        for model in self.model_list:
            # nn.Module.to / eval act in place; rebinding the loop variable
            # would have no effect.
            model.to(self.device)
            model.eval()
        self.bounds = bounds
        self.num_classes = num_classes
        self.n_queries = 0
        self.quantization = quantization
        self.hash_size = hash_size
        self.model_hash = model_hash
        self.hash_mode = hash_mode
        self.hash_list = []

        if self.model_hash is not None:
            self.model_hash = self.model_hash.to(self.device)
            # Eval mode, like the members, so that train-time layers (if any)
            # cannot make the routing vary between calls.
            self.model_hash.eval()

    def eval(self):
        """Put every member model, and the hash network if any, in eval mode."""
        for model in self.model_list:
            model.eval()
        if self.model_hash is not None:
            self.model_hash.eval()

    def coherence(self, x):
        """Return, per sample, the highest cosine similarity between any two
        members' softmax outputs.

        ``x`` is fed to the models as-is; ``preprocess`` is not applied.
        Assumes each model returns (batch, num_classes) logits, in which case
        the result has shape (batch,). TODO: confirm.

        Raises:
            ValueError: if fewer than two member models are available.
        """
        n = len(self.model_list)
        if n < 2:
            raise ValueError("coherence needs at least two models to compare")

        probs = [F.softmax(model(x), dim=-1) for model in self.model_list]
        # Similarity over the last (class) dim, for every unordered model pair.
        pairwise = [
            F.cosine_similarity(probs[i], probs[j], dim=-1).detach()
            for i in range(n)
            for j in range(i + 1, n)
        ]
        return torch.stack(pairwise, dim=0).max(dim=0).values

    def to(self, device: str):
        """Move the member models and the hash network to ``device``.

        ``self.device`` is updated as well, so ``preprocess`` places numpy
        inputs on the same device as the models. Returns ``self``.
        """
        self.device = device
        for model in self.model_list:
            model.to(device)
        if self.model_hash is not None:
            self.model_hash = self.model_hash.to(device)
        return self

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        """Round every value to the nearest multiple of 1/128."""
        return torch.round(x * 128.0) / 128.0

    def clamp(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp every value into ``[bounds[0], bounds[1]]``."""
        return torch.clamp(x, self.bounds[0], self.bounds[1])

    def get_hash_list(self, x: torch.Tensor) -> List[int]:
        """Return one member index per sample of ``x``, using ``hash_mode``.

        Raises:
            ValueError: for an unrecognised ``hash_mode``.
        """
        if self.hash_mode == "phash":
            return self.get_phash_list(x)
        elif self.hash_mode == "dnn":
            return self.get_dnnhash_list(x)
        elif self.hash_mode == "sha1":
            return self.get_sha1hash_list(x)
        else:
            raise ValueError("Unknown hash_mode")

    def get_sha1hash_list(self, x: torch.Tensor) -> List[int]:
        """Route each sample by the SHA-1 digest of its raw element bytes.

        The digest is over the values as stored, so it depends on the dtype.
        It is reduced modulo ``n_models``.
        """
        indices = []
        for xi in x:
            # detach: .numpy() refuses tensors that require grad.
            # ascontiguousarray: hashlib needs a contiguous buffer. For a
            # sample that is already contiguous, the hashed bytes are unchanged.
            buf = np.ascontiguousarray(xi.detach().cpu().numpy())
            indices.append(int(sha1(buf).hexdigest(), 16) % self.n_models)
        return indices

    def get_dnnhash_list(self, x: torch.Tensor) -> List[int]:
        """Route each sample using the prediction of ``self.model_hash``.

        The index is
        ``(m * argmax + floor((M * p_max - 1) * m / (M - 1))) % n_models``,
        where ``argmax`` and ``p_max`` are the top class and top softmax
        probability of the hash network.
        """
        y = F.softmax(self.model_hash(x), dim=-1)
        y_class = torch.argmax(y, dim=-1)
        M = 10  # fixed constants of the bucketing formula
        m = 2
        y_hash = (
            (m * y_class) + torch.floor((M * torch.max(y, dim=-1)[0] - 1) * m / (M - 1))
        ) % self.n_models
        # The floor term makes y_hash a float tensor. Cast it, so callers get
        # real ints; __call__ needs integer indices.
        return y_hash.detach().long().cpu().tolist()

    def get_phash_list(self, x: torch.Tensor) -> List[int]:
        """Route each sample by a hash of its clamped, quantized content.

        Values are clamped to ``bounds``, quantized, and rescaled to [0, 255].
        For 2-D input, each row is bucketed to 0..32 and folded into a
        position-weighted integer seed. A private ``random.Random`` seeded
        with it draws the index. Any other input is treated as a batch of
        channel-first images and hashed with ``phash``.
        """
        xq = self.quantize(self.clamp(x))
        xq = 255 * (xq - self.bounds[0]) / (self.bounds[1] - self.bounds[0])
        # Only values are needed below; .numpy() fails on autograd tensors.
        xq = xq.detach()

        hash_list = []
        if xq.ndim == 2:  # simple grid hash
            for xi in xq:
                buckets = torch.round(32 * xi / 255).cpu().tolist()
                seed = sum((i + 1) * v for i, v in enumerate(buckets))
                # Same draw as reseeding the global `random` module would give,
                # but without clobbering the caller's global RNG state.
                hash_list.append(random.Random(int(seed)).randrange(self.n_models))
            return hash_list

        # perceptual hash
        pil_mode = "L" if xq.shape[1] == 1 else "RGB"
        for xi in xq:
            # channel-first (C, H, W) -> channel-last (H, W, C) for PIL
            arr = xi.cpu().numpy().transpose(1, 2, 0)
            if pil_mode == "L":
                arr = arr.squeeze(-1)  # single-channel PIL images are 2-D
            pil_image = Image.fromarray(arr.astype("uint8"), pil_mode)
            digest = phash(pil_image, self.hash_size)
            hash_list.append(int(str(digest), 16) % self.n_models)
        return hash_list

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Prepare input for the models.

        Numpy input becomes a float tensor on ``self.device``. A 1-D or 3-D
        input is assumed to be a single vector/image and gets a batch
        dimension. Values are clamped to ``bounds`` and, if ``quantization``
        is enabled, quantized.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).type(torch.FloatTensor).to(self.device)
        if len(x.size()) in [1, 3]:  # assuming a vector/image as input
            x = x.unsqueeze(0)
        x = self.clamp(x)
        if self.quantization:
            x = self.quantize(x)
        return x

    def update_stats(self, hash_list):
        """Append the given routing indices to the running ``self.hash_list``."""
        self.hash_list += hash_list

    def __call__(
        self, x: torch.Tensor, label: bool = False, index=-1, x_hash=None
    ) -> torch.Tensor:
        """Run the ensemble on ``x`` (see ``preprocess`` for accepted inputs).

        Args:
            x: a single vector/image or a batch; numpy arrays are accepted.
            label: if True, return argmax class indices instead of raw outputs.
            index: -1 routes every sample to the member chosen by
                ``get_hash_list``. Any other value uses ``model_list[index]``
                for the whole batch.
            x_hash: optional input to hash for routing, used instead of the
                preprocessed ``x``. It is not preprocessed here.
        """
        x = self.preprocess(x)
        self.n_queries += x.shape[0]

        if index != -1:
            out = self.model_list[index](x)
        else:
            out_all = torch.stack([model(x) for model in self.model_list], dim=0)
            hash_list = self.get_hash_list(x if x_hash is None else x_hash)
            self.update_stats(hash_list)
            # Pick, for sample k, the output of member hash_list[k].
            rows = torch.as_tensor(hash_list, dtype=torch.long, device=out_all.device)
            cols = torch.arange(x.shape[0], device=out_all.device)
            out = out_all[rows, cols]

        if label:
            out = torch.argmax(out, dim=-1)
        return out

    def get_n_queries(self) -> int:
        """Return the total number of samples passed through ``__call__``."""
        return self.n_queries
