# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import clip
import torchvision.datasets as datasets
from PIL import ImageFile
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
import sys
import os
import random
# Directory one level above the folder that contains this file; it is appended
# to sys.path, presumably so that sibling packages of this folder can be
# imported -- NOTE(review): confirm which imports rely on this.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

# Tell PIL to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class ImageTextData(object):
    """Image-folder dataset pairing every image with CLIP-tokenized text prompts.

    Images are read from ``<root>/<dataset>`` through
    ``torchvision.datasets.ImageFolder``; the class-folder names become the
    label texts. With ``method='ours'`` every item additionally carries a
    tokenized "domain" prompt.

    Args:
        dataset: Name of the domain sub-folder inside ``root``.
        root: Dataset root directory. Its last path component (for example
            ``'PACS'``) selects the built-in domain-name table.
        preprocess: Accepted for interface compatibility but not used; the
            class-level transforms below are applied instead.
        sign: Truthy selects ``_transform`` (train), falsy selects
            ``_TRANSFORM`` (test). Both pipelines are currently identical.
        prompt: Prefix prepended to every class name (``'<prompt> <class>'``)
            when ``method`` is not ``'ours'``. A falsy value leaves the class
            names unchanged.
        method: ``'ours'`` switches to randomly chosen class-prompt templates
            and makes ``__getitem__`` return an extra domain prompt.
        r_p_num: ``1`` uses a single domain-prompt template; any other value
            uses all three.
        a_p_num: ``1`` uses a single class-prompt template; any other value
            uses all three (only relevant for ``method='ours'``).
        domain: Domain name used in the domain prompt when ``dataset`` is not
            found in the built-in tables.
    """

    # Templates for the domain ("r") prompts and the class ("a") prompts.
    _R_PROMPTS = ('this is {} domain', 'a {} domain', 'the domain of {}')
    _A_PROMPTS = ('a picture of a {}', 'a photo containing {}', '{} in a photo')

    # Per dataset family: sub-folder name -> human-readable domain name.
    # The values, in insertion order, are the candidate names for the
    # randomly drawn domain prompt in __getitem__.
    _DOMAIN_NAMES = {
        'OfficeHome': {
            'Art': 'art',
            'Clipart': 'clipart',
            'Product': 'product',
            'RealWorld': 'real world',
        },
        'ModernOffice31': {
            'amazon': 'amazon',
            'dslr': 'dslr',
            'synthetic': 'synthetic',
            'webcam': 'webcam',
        },
        'PACS': {
            'cartoon': 'cartoon',
            'photo': 'photo',
            'sketch': 'sketch',
            'art_painting': 'art painting',
        },
    }

    def __init__(self, dataset, root, preprocess, sign, prompt='a picture of a', method=None, r_p_num=3, a_p_num=3, domain=None):
        # Any value other than 1 selects the full three-template set.
        r_prompt = list(self._R_PROMPTS[:1] if r_p_num == 1 else self._R_PROMPTS)

        # The dataset family is the last component of ``root`` (one trailing
        # '/' is ignored). ``endswith`` also copes with an empty ``root``.
        select_dataset = root[:-1] if root.endswith('/') else root
        select_dataset = select_dataset.split('/')[-1]
        self.select_dataset = select_dataset
        print(select_dataset)

        # Unknown families get an empty candidate list (instead of leaving the
        # attribute undefined) and keep the caller-supplied ``domain``.
        domain_names = self._DOMAIN_NAMES.get(select_dataset, {})
        self.domains = list(domain_names.values())
        domain = domain_names.get(dataset, domain)
        if method == 'ours' and select_dataset == 'OfficeCaltech':
            self.domains = ['amazon', 'dslr', 'webcam']

        transform = self._transform if sign else self._TRANSFORM
        data = datasets.ImageFolder(os.path.join(root, dataset), transform=transform)
        self.data = data
        self.labels = data.classes

        if prompt:
            if method == 'ours':
                a_prompt = list(self._A_PROMPTS[:1] if a_p_num == 1 else self._A_PROMPTS)
                # One randomly chosen template per class name.
                self.labels = [a_prompt[random.randint(0, len(a_prompt) - 1)].format(x) for x in self.labels]
            else:
                self.labels = [prompt + ' ' + x for x in self.labels]

        # Images are transformed in __getitem__ with this pipeline; the
        # ``preprocess`` argument is intentionally left unused (as before).
        self.preprocess = transform

        self.text = clip.tokenize(self.labels)
        self.method = method
        self.r_prompt = r_prompt
        self.domain = domain

    def __getitem__(self, index):
        """Return ``(image, text_tokens, label)`` for ``index``.

        With ``method='ours'`` a fourth element is appended: the tokenized
        domain prompt. For datasets other than ``'BrainTumor'`` that have a
        known domain list, the prompt names a randomly drawn domain with
        probability 1/4 and this dataset's own domain otherwise.
        """
        path, label = self.data.imgs[index]
        if self.preprocess is not None:
            # Close the file handle as soon as the pixels have been converted.
            with Image.open(path) as img:
                image = self.preprocess(img)
        else:
            image = path
        text_enc = self.text[label]
        if self.method != 'ours':
            return image, text_enc, label

        # ``self.domains`` is checked first so that datasets without a domain
        # list fall back to ``self.domain`` instead of failing.
        use_random_domain = (
            bool(self.domains)
            and self.select_dataset != 'BrainTumor'
            and random.randint(0, 3) == 0
        )
        template = self.r_prompt[random.randint(0, len(self.r_prompt) - 1)]
        if use_random_domain:
            domain_name = self.domains[random.randint(0, len(self.domains) - 1)]
        else:
            domain_name = self.domain
        r_text_enc = clip.tokenize(template.format(domain_name))[0]
        return image, text_enc, label, r_text_enc

    def __len__(self):
        """Return the number of images in the underlying ImageFolder."""
        return len(self.data)

    @staticmethod
    def get_data_name_by_index(index):
        """Return entry ``index`` of ``_DATA_FOLDER`` with '/' replaced by '_'."""
        # NOTE(review): ``_DATA_FOLDER`` is not defined on this class in this
        # file; unless it is attached elsewhere this raises AttributeError.
        name = ImageTextData._DATA_FOLDER[index]
        name = name.replace('/', '_')
        return name

    # Test-time pipeline: RGB conversion, resize to 224x224, tensor conversion
    # and normalization (the mean/std are the commonly used ImageNet values).
    _TRANSFORM = transforms.Compose([
        transforms.Lambda(lambda x: x.convert('RGB')),
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # Train-time pipeline; currently the same steps as ``_TRANSFORM`` (random
    # crop / horizontal flip are disabled).
    _transform = transforms.Compose([
        transforms.Lambda(lambda x: x.convert('RGB')),
        transforms.Resize([224, 224]),
        # transforms.RandomCrop(224),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Original author's note: 91.6667 (presumably a result obtained with
        # this setting -- unverified).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

def get_data(data_name):
    """Return the module-level loader factory registered for ``data_name``.

    Args:
        data_name: Dataset name, e.g. ``'PACS'`` or ``'OfficeHome'``.

    Returns:
        The module-level object whose name is registered for ``data_name``.

    Raises:
        NotImplementedError: If ``data_name`` is not a registered dataset name,
            or no object with the registered name exists at module level.
    """
    datalist = {'OfficeHome': 'OfficeHome', 'BrainTumor': 'BrainTumor',
                'RealSkin': 'RealSkin', 'Dermnet': 'Dermnet', 'ModernOffice31': 'ModernOffice31',
                'havior': 'havior', 'PACS': 'PACS'}
    # ``.get`` so that an unregistered name reaches the NotImplementedError
    # below instead of escaping as a KeyError.
    func_name = datalist.get(data_name)
    module_globals = globals()
    if func_name is None or func_name not in module_globals:
        raise NotImplementedError("Dataset not found: {}".format(data_name))
    return module_globals[func_name]


def getfeadataloader(args, model):
    """Build the per-domain data loaders for training and evaluation.

    Every entry of ``args.domains`` whose position is listed in
    ``args.test_envs`` is treated as a target domain; all others are source
    domains that are shuffled (seeded with ``args.seed``) and split
    80% / 10% / 10% into train / validation / test subsets.

    Returns:
        A tuple ``(trd, vad, ted, telt, tr_td)`` of lists:
        ``trd``   - infinite training loaders (``0`` for target domains);
        ``vad``   - validation loaders (``0`` for target domains);
        ``ted``   - test loaders (for a target domain: the whole domain);
        ``telt``  - infinite loaders over target domains only;
        ``tr_td`` - finite loaders over the training split (``0`` for
                    target domains).
    """
    train_loaders, val_loaders, test_loaders = [], [], []
    target_loaders, train_eval_loaders = [], []

    def build_dataset(domain_name, sign):
        # One dataset object for ``domain_name``; ``sign`` picks the transform.
        return ImageTextData(
            domain_name, args.root_dir + args.dataset + '/', model.preprocess,
            sign=sign, method=args.method,
            r_p_num=args.r_p_num, a_p_num=args.a_p_num)

    for env_idx, domain_name in enumerate(args.domains):
        if env_idx not in args.test_envs:
            # Source domain: seeded shuffle followed by an 80/10/10 split.
            source = build_dataset(domain_name, 1)
            n_total = len(source)
            order = np.arange(n_total)
            np.random.seed(args.seed)
            np.random.shuffle(order)
            n_train = int(n_total * 0.8)
            n_val = int(n_total * 0.1)
            n_test = int(n_total * 0.1)

            train_subset = torch.utils.data.Subset(source, order[:n_train])
            val_subset = torch.utils.data.Subset(
                source, order[n_train:n_train + n_val])
            test_subset = torch.utils.data.Subset(
                source, order[n_train + n_val:n_train + n_val + n_test])
            train_eval_subset = torch.utils.data.Subset(source, order[:n_train])

            train_loaders.append(get_data_loader(
                train_subset, batch_size=args.batch, shuffle=True,
                drop_last=True, infinite_data_loader=True))
            train_eval_loaders.append(get_data_loader(
                train_eval_subset, batch_size=args.batch, shuffle=True,
                drop_last=True, infinite_data_loader=False))
            val_loaders.append(torch.utils.data.DataLoader(
                val_subset, batch_size=args.batch, shuffle=False, drop_last=True))
            test_loaders.append(torch.utils.data.DataLoader(
                test_subset, batch_size=args.batch, shuffle=False, drop_last=False))
            continue

        # Target domain: the full domain is the test set, and a second copy
        # (built with sign=1) feeds an infinite loader.
        target_eval = build_dataset(domain_name, 0)
        target_stream = build_dataset(domain_name, 1)
        model.setselflabel(target_eval.labels)
        test_loaders.append(torch.utils.data.DataLoader(
            target_eval, batch_size=args.batch, shuffle=False))
        target_loaders.append(get_data_loader(
            target_stream, args.batch, infinite_data_loader=True))
        # Placeholders for the lists that have no target-domain loader.
        train_loaders.append(0)
        val_loaders.append(0)
        train_eval_loaders.append(0)

    return train_loaders, val_loaders, test_loaders, target_loaders, train_eval_loaders


def getfeadataloader_by_alpha(args, model):
    """Build data loaders with each source domain split into ``args.K`` shards.

    Target domains (positions listed in ``args.test_envs``) are handled as in
    ``getfeadataloader``. Every source domain is shuffled (seeded with
    ``args.seed``) and cut into ``args.K`` consecutive, non-overlapping shards
    whose sizes follow a Dirichlet draw with concentration ``args.alpha``;
    each shard is then split 80% / 10% / 10% into train / validation / test.

    Returns:
        A tuple ``(trd, vad, ted, telt, tr_td)`` of lists. A source domain
        contributes ``args.K`` entries to ``trd``, ``vad``, ``ted`` and
        ``tr_td``; a target domain contributes one entry to each (``0``
        placeholders except in ``ted``) plus one infinite loader in ``telt``.
        List positions therefore do not correspond to domain indices.
    """
    trd, vad, ted, telt, tr_td = [], [], [], [], []
    for env_idx, item in enumerate(args.domains):
        if env_idx in args.test_envs:  # target domain data
            data = ImageTextData(
                item, args.root_dir + args.dataset + '/', model.preprocess, sign=0, method=args.method, r_p_num=args.r_p_num, a_p_num=args.a_p_num)
            data2 = ImageTextData(
                item, args.root_dir + args.dataset + '/', model.preprocess, sign=1, method=args.method, r_p_num=args.r_p_num, a_p_num=args.a_p_num)
            model.setselflabel(data.labels)
            ted.append(torch.utils.data.DataLoader(
                data, batch_size=args.batch, shuffle=False))
            telt.append(get_data_loader(
                data2, args.batch, infinite_data_loader=True))
            # Placeholders for the lists that have no target-domain loader.
            trd.append(0)
            vad.append(0)
            tr_td.append(0)
            continue

        data = ImageTextData(item, args.root_dir + args.dataset + '/', model.preprocess, sign=1, method=args.method, r_p_num=args.r_p_num, a_p_num=args.a_p_num)
        n_total = len(data)
        index = np.arange(n_total)
        np.random.seed(args.seed)
        np.random.shuffle(index)
        proportions = np.random.dirichlet(alpha=np.ones(args.K) * args.alpha)
        shard_sizes = (proportions * n_total).astype(int)
        # Each shard starts where all previous shards end. Using only the
        # size of the immediately preceding shard as the offset would make
        # shards overlap as soon as K >= 3.
        shard_starts = np.cumsum(shard_sizes) - shard_sizes
        for k in range(args.K):
            start = int(shard_starts[k])
            size = int(shard_sizes[k])
            n_train, n_val, n_test = int(size * 0.8), int(size * 0.1), int(size * 0.1)  # train/val/test
            train_end = start + n_train
            val_end = train_end + n_val
            test_end = val_end + n_test

            train_subset = torch.utils.data.Subset(data, index[start:train_end])
            val_subset = torch.utils.data.Subset(data, index[train_end:val_end])
            test_subset = torch.utils.data.Subset(data, index[val_end:test_end])
            train_eval_subset = torch.utils.data.Subset(data, index[start:train_end])

            trd.append(get_data_loader(train_subset, batch_size=args.batch, shuffle=True, drop_last=True,
                                       infinite_data_loader=True))
            tr_td.append(get_data_loader(train_eval_subset, batch_size=args.batch, shuffle=True, drop_last=True,
                                         infinite_data_loader=False))
            vad.append(torch.utils.data.DataLoader(
                val_subset, batch_size=args.batch, shuffle=False, drop_last=True))
            ted.append(torch.utils.data.DataLoader(
                test_subset, batch_size=args.batch, shuffle=False, drop_last=False))
    return trd, vad, ted, telt, tr_td

def BrainTumor(args, model):
    """Return the BrainTumor loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def OfficeHome(args, model):
    """Return the OfficeHome loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def RealSkin(args, model):
    """Return the RealSkin loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)


def havior(args, model):
    """Return the havior loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def ModernOffice31(args, model):
    """Return the ModernOffice31 loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def Dermnet(args, model):
    """Return the Dermnet loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def PACS(args, model):
    """Return the PACS loaders; delegates to ``getfeadataloader``."""
    return getfeadataloader(args, model)

def PACS_by_alpha(args, model):
    """Return the PACS loaders; currently delegates to ``getfeadataloader``.

    NOTE(review): despite the ``_by_alpha`` suffix this does not call
    ``getfeadataloader_by_alpha`` -- confirm which split is intended. The
    existing behaviour is kept; only the unreachable trailing ``pass`` was
    removed.
    """
    return getfeadataloader(args, model)

def get_data_loader(dataset, batch_size, shuffle=True, drop_last=True, infinite_data_loader=False,
                    **kwargs):
    """Create a finite ``DataLoader`` or an ``InfiniteDataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether the finite loader reshuffles the data every epoch.
            It is forwarded to ``InfiniteDataLoader`` as well, which in this
            file always samples randomly regardless of the value.
        drop_last: Whether to drop the final, incomplete batch.
        infinite_data_loader: If true, return an ``InfiniteDataLoader``.
        **kwargs: Extra keyword arguments forwarded to the loader constructor.

    Returns:
        A ``torch.utils.data.DataLoader`` or an ``InfiniteDataLoader``.
    """
    if infinite_data_loader:
        return InfiniteDataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                  **kwargs)
    # Honour the caller's ``shuffle`` flag instead of always shuffling.
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                       **kwargs)


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """Data loader whose iterator never runs out of batches.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Accepted but not used; sampling is always random.
        drop_last: Whether incomplete batches are dropped.
        num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        weights: Optional per-sample weights; when given, a
            ``WeightedRandomSampler`` drawing ``batch_size`` samples without
            replacement per pass is used instead of a ``RandomSampler``.
        **kwargs: Accepted but not used.
    """

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, weights=None, **kwargs):
        if weights is None:
            index_sampler = torch.utils.data.RandomSampler(
                dataset, replacement=False)
        else:
            index_sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=False, num_samples=batch_size)

        # Group indices into batches, then make that stream of batches endless.
        endless_batches = _InfiniteSampler(torch.utils.data.BatchSampler(
            index_sampler, batch_size=batch_size, drop_last=drop_last))

        loader = torch.utils.data.DataLoader(
            dataset, num_workers=num_workers, batch_sampler=endless_batches)
        # A single iterator is created here and shared by every __iter__ call.
        self._infinite_iterator = iter(loader)

    def __iter__(self):
        while True:
            batch = next(self._infinite_iterator)
            yield batch

    def __len__(self):
        # The stream has no meaningful length; 0 is reported unconditionally.
        return 0