from thop import profile
import os
import json
from torch.utils.data import TensorDataset
import numpy as np
import torch
from loguru import logger
from PIL import Image


class CusDataset(TensorDataset):
    """Dataset over a ``{"x": ..., "y": ...}`` mapping with an optional transform.

    Items are returned as ``[x, y]`` tensors placed on the current CUDA device
    when CUDA is available, and on the CPU otherwise. When pseudo data has
    been attached through :meth:`set_psuedo_data`, every item additionally
    carries one ``[x, y]`` pair taken from each pseudo source.

    NOTE(review): ``TensorDataset.__init__`` is not called, so the inherited
    ``tensors`` attribute is never set; samples live in ``self.data`` instead.
    """

    def __init__(self, data, transform=None):
        """Store the samples and the optional per-sample transform.

        Args:
            data: Container supporting ``in`` and indexing that holds the
                keys ``"x"`` (samples) and ``"y"`` (labels).
            transform: Optional callable applied to a ``PIL.Image`` built
                from each sample; its result must support ``.to(device)``.
                When ``None``, samples are converted with ``torch.tensor``.
        """
        assert "x" in data
        assert "y" in data
        self.data = {"x": data["x"], "y": data["y"]}
        self.transform = transform
        # Optional pseudo-data sources, see set_psuedo_data().
        self.data1 = None
        self.data2 = None

    @staticmethod
    def _device():
        """Return the CUDA device when available, otherwise the CPU device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __getitem__(self, item):
        """Return the sample at ``item``.

        Returns:
            ``[x, y]`` when no pseudo data is attached; otherwise the tuple
            ``([x, y], [x1, y1], [x2, y2])`` where the extra pairs are read
            from ``self.data1[item]`` and ``self.data2[item]``.
        """
        device = self._device()
        if self.transform is None:
            ret = torch.tensor(self.data["x"][item]).to(device)
        else:
            img = np.array(self.data["x"][item]).astype("uint8")
            # A leading axis of size 3 with a trailing axis that is not 3 is
            # treated as channel-first and moved to channel-last before the
            # array is handed to PIL; every other layout is passed through.
            if img.shape[-1] != 3 and img.shape[0] == 3:
                img = img.transpose(1, 2, 0)
            ret = self.transform(Image.fromarray(img)).to(device)

        label = torch.tensor(self.data["y"][item]).to(device)
        if self.data1 is None:
            return [ret, label]

        # data1 entries are unpacked as (x, <ignored>, y) and data2 entries
        # as (x, y, <ignored>); only element 0 of each x / y is used.
        x1, _, y1 = self.data1[item]
        x2, y2, _ = self.data2[item]
        return [ret, label], [x1[0], y1[0]], [x2[0], y2[0]]

    def __len__(self):
        """Return the number of samples stored under ``"x"``."""
        return len(self.data["x"])

    def set_psuedo_data(self, data1, data2):
        """Attach two indexable pseudo-data sources consulted by ``__getitem__``."""
        self.data1 = data1
        self.data2 = data2

    def get_total_data(self):
        """Return every ``[x, y]`` item as a list.

        Side effect: any attached pseudo data is cleared first (and stays
        cleared afterwards), so the returned items never contain pseudo pairs.
        """
        self.data1 = None
        self.data2 = None
        return [self[i] for i in range(len(self))]


def Flops(model, inp):
    """Profile ``model`` on the single input ``inp`` and return the op count.

    ``thop.profile`` is run quietly (``verbose=False``) with ``inp`` as the
    only positional model input; the first element of its result is returned.
    NOTE(review): thop describes that first element as the MAC count rather
    than FLOPs — confirm which unit callers expect.
    """
    profile_result = profile(model, inputs=(inp,), verbose=False)
    return profile_result[0]


def _load_json_dir(data_path):
    """Merge every ``*.json`` file directly inside ``data_path``.

    Files are processed in sorted name order so the result does not depend on
    the arbitrary order returned by ``os.listdir``.

    Returns:
        ``(groups, user_data)``: the concatenated ``"hierarchies"`` lists
        (files lacking that key contribute nothing) and the merged
        ``"user_data"`` mappings (on duplicate keys the later file wins).
    """
    groups = []
    user_data = {}
    json_names = sorted(
        name for name in os.listdir(data_path) if name.endswith(".json")
    )
    for name in json_names:
        file_path = os.path.join(data_path, name)
        # JSON is decoded as UTF-8 instead of the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as inf:
            cdata = json.load(inf)
        if "hierarchies" in cdata:
            groups.extend(cdata["hierarchies"])
        user_data.update(cdata["user_data"])
    return groups, user_data


def read_data(train_data_path, test_data_path):
    """Load per-user train and test data from two directories of JSON files.

    Args:
        train_data_path: Directory whose ``*.json`` files each contain a
            ``"user_data"`` mapping and, optionally, a ``"hierarchies"`` list.
        test_data_path: Directory whose ``*.json`` files each contain a
            ``"user_data"`` mapping.

    Returns:
        ``(clients, groups, train_data, test_data)`` where ``clients`` is the
        sorted list of keys of the merged training ``"user_data"``, ``groups``
        is the concatenation of the training files' ``"hierarchies"`` lists,
        and ``train_data`` / ``test_data`` are the merged ``"user_data"``
        mappings of the respective directory.

    Raises:
        OSError: If a directory or file cannot be read.
        KeyError: If a JSON file has no ``"user_data"`` entry.
    """
    groups, train_data = _load_json_dir(train_data_path)
    # "hierarchies" found in test files are deliberately not reported.
    _, test_data = _load_json_dir(test_data_path)
    clients = sorted(train_data)
    return clients, groups, train_data, test_data


def decode_stat(stat):
    """Log the aggregate accuracy (and mean loss, when present) of ``stat``.

    Args:
        stat: Sized iterable of 4 elements
            ``(ids, groups, num_samples, tot_correct)`` or of 5 elements with
            ``losses`` appended. Only ``num_samples``, ``tot_correct`` and
            ``losses`` are read; each is reduced with ``sum``.

    Raises:
        ValueError: If ``stat`` does not contain exactly 4 or 5 elements.
        ZeroDivisionError: If ``num_samples`` sums to zero (plain Python
            numbers).
    """
    n_fields = len(stat)
    if n_fields not in (4, 5):
        raise ValueError(
            "stat must contain 4 or 5 elements, got {}".format(n_fields)
        )

    # ids and groups are part of the layout but are not used for logging.
    _ids, _groups, num_samples, tot_correct, *optional_losses = stat
    total_samples = sum(num_samples)
    accuracy = sum(tot_correct) / total_samples
    if optional_losses:
        mean_loss = sum(optional_losses[0]) / total_samples
        logger.info("Accuracy: {} Loss: {}".format(accuracy, mean_loss))
    else:
        logger.info("Accuracy: {}".format(accuracy))
