import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler
import numpy as np
import scipy as sp
import scipy.stats
import random
import scipy.io as sio
from sklearn import preprocessing
import torch.nn as nn
import torch.nn.functional as F
from augment import CenterResizeCrop


class Task(object):
    """One sampled episode holding paired support/query samples of two modalities.

    ``data1`` and ``data2`` are mappings from a class key to a sequence of
    samples; sample ``i`` of ``data1[c]`` is paired with sample ``i`` of
    ``data2[c]``.  ``num_classes`` class keys are drawn at random from
    ``data1``; for each of them the pairs are shuffled and split into a
    support part and a query part.

    Attributes:
        support_datas1, support_datas2: support samples of the two modalities,
            index-aligned with each other and with ``support_labels``.
        query_datas1, query_datas2: query samples, index-aligned with each
            other and with ``query_labels``.
        support_num, query_num: split sizes computed for the last processed
            class (``0`` when no class was sampled).
    """

    def __init__(self, data1, data2, num_classes, support_ratio, query_ratio):
        """Sample the classes and build the support/query lists.

        Args:
            data1: mapping class key -> samples of the first modality.
            data2: mapping class key -> samples of the second modality.
            num_classes: how many class keys to sample from ``data1``.
            support_ratio: fraction of a class's pairs used as support samples.
            query_ratio: fraction of a class's pairs used as query samples.

        Raises:
            ValueError: from ``random.sample`` when ``num_classes`` exceeds the
                number of available classes (or is negative).
            KeyError: when a sampled class key is missing from ``data2``.
        """
        self.data1 = data1
        self.data2 = data2
        self.num_classes = num_classes
        self.support_ratio = support_ratio
        self.query_ratio = query_ratio

        class_folders = sorted(data1)
        class_list = random.sample(class_folders, self.num_classes)

        # The i-th sampled class is relabelled with the i-th smallest sampled key.
        labels = dict(zip(class_list, np.sort(class_list)))

        self.support_datas1 = []
        self.query_datas1 = []
        self.support_datas2 = []
        self.query_datas2 = []
        self.support_labels = []
        self.query_labels = []

        # Defined up front so the attributes below exist even when no class is sampled.
        support_num = 0
        query_num = 0

        for c in class_list:
            # Pair the two modalities sample by sample, then shuffle the pairs together
            # so the pairing survives the shuffle.
            pairs = list(zip(self.data1[c], self.data2[c]))
            random.shuffle(pairs)

            # Size the split from the number of actual pairs: zip() truncates to the
            # shorter modality, so len(data1[c]) could overstate what is available.
            total_pairs = len(pairs)
            support_num = int(total_pairs * self.support_ratio)
            query_num = int(total_pairs * self.query_ratio)

            support_pairs = pairs[:support_num]
            query_pairs = pairs[support_num:support_num + query_num]

            self.support_datas1 += [first for first, _ in support_pairs]
            self.query_datas1 += [first for first, _ in query_pairs]

            self.support_datas2 += [second for _, second in support_pairs]
            self.query_datas2 += [second for _, second in query_pairs]

            # One label per sample actually taken, so labels never outnumber the data
            # (e.g. when support_ratio + query_ratio > 1 shortens the query slice).
            self.support_labels += [labels[c]] * len(support_pairs)
            self.query_labels += [labels[c]] * len(query_pairs)

        self.support_num = support_num
        self.query_num = query_num


class FewShotDataset(Dataset):
    """Base dataset exposing either the support or the query part of a task.

    With ``split == 'train'`` the task's support lists are used; any other
    value selects the query lists.  ``__getitem__`` is left to subclasses.
    """

    def __init__(self, task, transformer = None, split='train'):
        """Store the task and pick the sample/label lists matching ``split``.

        Args:
            task: object providing ``support_datas1/2``, ``query_datas1/2``,
                ``support_labels`` and ``query_labels``.
            transformer: stored as-is for subclasses to use; may be ``None``.
            split: ``'train'`` for the support part, anything else for query.
        """
        self.task = task
        self.split = split
        self.transformer = transformer

        if self.split == 'train':
            self.data = self.task.support_datas1
            self.data_LIDAR = self.task.support_datas2
            self.labels = self.task.support_labels
        else:
            self.data = self.task.query_datas1
            self.data_LIDAR = self.task.query_datas2
            self.labels = self.task.query_labels

    def __len__(self):
        """Return the number of labels in the selected split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Not implemented here; concrete datasets must override this."""
        raise NotImplementedError("This is an abstract class. Subclass this class for your particular dataset.")


class HBKC_dataset(FewShotDataset):
    """Dataset returning ``(img, img_LIDAR, label)`` triples as torch tensors.

    ``self.transformer`` is either ``None`` (samples are converted unchanged)
    or an indexable collection of callables.  When it holds exactly two
    callables both are applied in order; for any other length only the first
    one is applied.
    """

    def _to_tensor(self, sample):
        """Apply the configured transform(s) to ``sample`` and wrap it in a tensor."""
        transforms = self.transformer
        if transforms is not None:
            # Decide before calling anything, mirroring the len() check-then-apply order.
            apply_second = len(transforms) == 2
            sample = transforms[0](sample)
            if apply_second:
                sample = transforms[1](sample)
        return torch.from_numpy(np.asarray(sample))

    def __getitem__(self, index):
        """Return the two modality tensors and the label stored at ``index``."""
        label = self.labels[index]
        # First modality is fully transformed before the second one is touched.
        img = self._to_tensor(self.data[index])
        img_LIDAR = self._to_tensor(self.data_LIDAR[index])
        return img, img_LIDAR, label


# Sampler
class ClassBalancedSampler(Sampler):
    """Index sampler for a dataset whose items are stored grouped by class.

    Every iteration yields each index in ``range(num_cl * num_per_class)``
    exactly once, where class ``j`` is assumed to occupy the indices
    ``j * num_per_class`` to ``(j + 1) * num_per_class - 1``.

    Parameters:
        num_per_class: number of samples per class.
        num_cl: number of classes.
        num_inst: number of samples in the support or query set; it is stored
            but not consulted when generating indices.
        shuffle: whether the indices are emitted in random order.
    """

    def __init__(self, num_per_class, num_cl, num_inst,shuffle=True):
        self.num_per_class = num_per_class
        self.num_cl = num_cl
        self.num_inst = num_inst
        self.shuffle = shuffle

    def __iter__(self):
        """Return an iterator over one flat list of dataset indices."""
        indices = []
        for class_idx in range(self.num_cl):
            offset = class_idx * self.num_per_class
            if self.shuffle:
                # Random order inside the class; entries are zero-dimensional tensors.
                positions = torch.randperm(self.num_per_class)
            else:
                positions = range(self.num_per_class)
            indices.extend([pos + offset for pos in positions])

        if self.shuffle:
            # Additionally mix the classes with each other.
            random.shuffle(indices)
        return iter(indices)

    def __len__(self):
        """Always report a length of 1."""
        return 1

class ClipLoss(torch.nn.Module):
    """Symmetric cross-entropy loss over scaled cosine similarities.

    Row ``i`` of the first feature matrix is treated as the match of row ``i``
    of the second one; all other rows act as negatives.
    """

    def __init__(self, logit_scale=None):
        """Create the learnable log-scale, initialised to ``log(1 / 0.07)``.

        NOTE: the ``logit_scale`` argument is accepted but not used.
        """
        super().__init__()
        initial_log_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_log_scale)

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return the target indices ``0 .. num_logits - 1`` on ``device``."""
        return torch.arange(num_logits, device=device, dtype=torch.long)

    def get_logits(self, image_features, text_features):
        """Return ``(logits_per_image, logits_per_text)``.

        Both inputs are divided by their row-wise (dim=1) norm, the image side
        is multiplied by ``exp(self.logit_scale)`` and the similarity matrix is
        formed; the second result is the transpose of the first.
        """
        image_unit = image_features / image_features.norm(dim=1, keepdim=True)
        text_unit = text_features / text_features.norm(dim=1, keepdim=True)

        scale = self.logit_scale.exp()

        # Scaled cosine similarities, one row per image feature.
        per_image = torch.matmul(scale * image_unit, text_unit.T)
        per_text = per_image.t()
        return per_image, per_text

    def forward(self, image_features, text_features):
        """Return the mean of the image-to-text and text-to-image cross-entropies."""
        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)

        per_image, per_text = self.get_logits(image_features, text_features)
        targets = self.get_ground_truth(image_features.device, per_image.shape[0])

        image_term = F.cross_entropy(per_image, targets)
        text_term = F.cross_entropy(per_text, targets)
        return (image_term + text_term) / 2


class Queue:
    """Fixed-capacity FIFO of feature rows backed by a float32 tensor.

    The storage is a ``(capacity, dim)`` tensor used as a circular buffer:
    ``front`` is the slot of the oldest element, ``rear`` the slot of the
    newest one and ``size`` the number of stored elements.
    """

    def __init__(self, capacity, dim):
        self.capacity = capacity
        self.size = 0
        self.queue = torch.empty((capacity, dim), dtype=torch.float32)
        self.front = 0
        self.rear = -1  # advanced to slot 0 by the first enqueue

    def is_full(self):
        """Return True when ``size`` has reached ``capacity``."""
        return self.capacity == self.size

    def is_empty(self):
        """Return True when nothing is stored."""
        return not self.size != 0

    def enqueue(self, item):
        """Store ``item`` in the next slot; print a message and do nothing if full."""
        if not self.is_full():
            self.rear = (self.rear + 1) % self.capacity
            self.queue[self.rear] = item
            self.size += 1
            return
        print("Queue is full")

    def dequeue(self):
        """Remove and return the oldest row (a view of the buffer), or None if empty."""
        if not self.is_empty():
            oldest = self.queue[self.front]
            self.front = (self.front + 1) % self.capacity
            self.size -= 1
            return oldest
        print("Queue is empty")
        return None

    def peek(self):
        """Return the oldest row without removing it, or None if empty."""
        if not self.is_empty():
            return self.queue[self.front]
        print("Queue is empty")
        return None

    def display(self):
        """Print the rows between ``front`` and ``rear``, handling wrap-around."""
        if self.is_empty():
            print("Queue is empty")
            return
        if self.front > self.rear:
            # The occupied region wraps past the end of the buffer.
            tail_part = self.queue[self.front:]
            head_part = self.queue[:self.rear+1]
            print(torch.cat((tail_part, head_part)))
        else:
            print(self.queue[self.front:self.rear+1])

    def get_queue(self):
        """Return the whole backing tensor, including unused slots."""
        return self.queue

    def get_num(self):
        """Return the number of stored elements."""
        return self.size


def get_loss_clip(capacity, dim_feature, label, features, text_features, num_classes=15):
    """Sum ``ClipLoss`` over ``capacity`` class-balanced groups of feature rows.

    Rows of ``features`` are distributed into one FIFO queue per class
    according to ``label``.  Then, ``capacity`` times, one row is taken from
    every queue in class order, the rows are stacked into a
    ``(num_classes, dim_feature)`` batch and its ``ClipLoss`` against
    ``text_features`` is added to the total.

    This still assumes every class occurs the same number of times
    (``capacity``) in the batch: a class with fewer rows makes its queue run
    empty, which surfaces as a ``TypeError`` when stacking, and rows beyond
    ``capacity`` are dropped by the queue.

    Args:
        capacity: number of rows kept per class, and number of loss terms.
        dim_feature: length of one feature row.
        label: iterable of class indices, aligned with the rows of ``features``.
            Values outside ``0 .. num_classes - 1`` are ignored.
        features: 2-D tensor indexed as ``features[n, :]``.
        text_features: second input of the loss; the stacked batch is moved to
            its device before the loss is evaluated.
        num_classes: number of per-class queues (default 15).

    Returns:
        The accumulated loss tensor, or the integer ``0`` when ``capacity`` is 0.
    """
    criterion = ClipLoss(logit_scale=5)
    queues = [Queue(capacity=capacity, dim=dim_feature) for _ in range(num_classes)]

    for row, value in enumerate(label):
        class_idx = int(value)
        # Skip anything that is not exactly one of the known class indices.
        if class_idx != value or not 0 <= class_idx < num_classes:
            continue
        queues[class_idx].enqueue(features[row, :])

    loss_clip_all = 0
    for _ in range(capacity):
        # One row per class, in class order, so row i lines up with class i.
        batch = torch.stack([queue.dequeue() for queue in queues], dim=0)
        batch = batch.to(text_features.device)
        loss_clip_all += criterion(image_features=batch, text_features=text_features)

    return loss_clip_all


def js_div(p_output, q_output, get_softmax=True):
    """Jensen-Shannon style divergence between two batches of outputs.

    Both inputs are first L2-normalised along dim 1.  When ``get_softmax`` is
    true a softmax over the last dimension is applied afterwards.  The result
    is the average of the two ``KLDivLoss(reduction='batchmean')`` values of
    each input measured against the log of their element-wise mean.
    """
    kl_batchmean = nn.KLDivLoss(reduction='batchmean')

    p_dist = F.normalize(p_output, p=2, dim=1)
    q_dist = F.normalize(q_output, p=2, dim=1)
    if get_softmax:
        p_dist = F.softmax(p_dist, dim=-1)
        q_dist = F.softmax(q_dist, dim=-1)

    mixture = (p_dist + q_dist) / 2
    log_mixture = mixture.log()

    kl_p = kl_batchmean(log_mixture, p_dist)
    kl_q = kl_batchmean(log_mixture, q_dist)
    return (kl_p + kl_q) / 2

def get_loss_clip_trento(capacity, dim_feature, label, features, text_features, num_classes=6):
    """Sum ``ClipLoss`` over ``capacity`` class-balanced groups (6 classes by default).

    Rows of ``features`` are distributed into one FIFO queue per class
    according to ``label``.  Then, ``capacity`` times, one row is taken from
    every queue in class order, the rows are stacked into a
    ``(num_classes, dim_feature)`` batch and its ``ClipLoss`` against
    ``text_features`` is added to the total.

    This still assumes every class occurs the same number of times
    (``capacity``) in the batch: a class with fewer rows makes its queue run
    empty, which surfaces as a ``TypeError`` when stacking, and rows beyond
    ``capacity`` are dropped by the queue.

    Args:
        capacity: number of rows kept per class, and number of loss terms.
        dim_feature: length of one feature row.
        label: iterable of class indices, aligned with the rows of ``features``.
            Values outside ``0 .. num_classes - 1`` are ignored.
        features: 2-D tensor indexed as ``features[n, :]``.
        text_features: second input of the loss; the stacked batch is moved to
            its device before the loss is evaluated.
        num_classes: number of per-class queues (default 6).

    Returns:
        The accumulated loss tensor, or the integer ``0`` when ``capacity`` is 0.
    """
    criterion = ClipLoss(logit_scale=5)
    queues = [Queue(capacity=capacity, dim=dim_feature) for _ in range(num_classes)]

    for row, value in enumerate(label):
        class_idx = int(value)
        # Skip anything that is not exactly one of the known class indices.
        if class_idx != value or not 0 <= class_idx < num_classes:
            continue
        queues[class_idx].enqueue(features[row, :])

    loss_clip_all = 0
    for _ in range(capacity):
        # One row per class, in class order, so row i lines up with class i.
        batch = torch.stack([queue.dequeue() for queue in queues], dim=0)
        batch = batch.to(text_features.device)
        loss_clip_all += criterion(image_features=batch, text_features=text_features)

    return loss_clip_all


def get_loss_clip_augsburg(capacity, dim_feature, label, features, text_features, num_classes=7):
    """Sum ``ClipLoss`` over ``capacity`` class-balanced groups (7 classes by default).

    Rows of ``features`` are distributed into one FIFO queue per class
    according to ``label``.  Then, ``capacity`` times, one row is taken from
    every queue in class order, the rows are stacked into a
    ``(num_classes, dim_feature)`` batch and its ``ClipLoss`` against
    ``text_features`` is added to the total.

    This still assumes every class occurs the same number of times
    (``capacity``) in the batch: a class with fewer rows makes its queue run
    empty, which surfaces as a ``TypeError`` when stacking, and rows beyond
    ``capacity`` are dropped by the queue.

    Args:
        capacity: number of rows kept per class, and number of loss terms.
        dim_feature: length of one feature row.
        label: iterable of class indices, aligned with the rows of ``features``.
            Values outside ``0 .. num_classes - 1`` are ignored.
        features: 2-D tensor indexed as ``features[n, :]``.
        text_features: second input of the loss; the stacked batch is moved to
            its device before the loss is evaluated.
        num_classes: number of per-class queues (default 7).

    Returns:
        The accumulated loss tensor, or the integer ``0`` when ``capacity`` is 0.
    """
    criterion = ClipLoss(logit_scale=5)
    queues = [Queue(capacity=capacity, dim=dim_feature) for _ in range(num_classes)]

    for row, value in enumerate(label):
        class_idx = int(value)
        # Skip anything that is not exactly one of the known class indices.
        if class_idx != value or not 0 <= class_idx < num_classes:
            continue
        queues[class_idx].enqueue(features[row, :])

    loss_clip_all = 0
    for _ in range(capacity):
        # One row per class, in class order, so row i lines up with class i.
        batch = torch.stack([queue.dequeue() for queue in queues], dim=0)
        batch = batch.to(text_features.device)
        loss_clip_all += criterion(image_features=batch, text_features=text_features)

    return loss_clip_all