import os
import torch
from torch.utils.data import ConcatDataset

from data.meta_dataset import MetaDataset, GetDataLoaderDict
from configs.default import pacs_path
from torchvision import transforms

transform_train = transforms.Compose(
            [transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            # transforms.RandomGrayscale( 0.1),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
            transforms.RandomGrayscale(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

transform_test = transforms.Compose(
            [transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

pacs_name_dict = {
    'p': 'photo',
    'a': 'art_painting',
    'c': 'cartoon',
    's': 'sketch',
}

split_dict = {
    'train': 'train',
    'val': 'crossval',
    'total': 'test',
}


def denormalization(imgs):
    device = imgs.device
    imgs = imgs * torch.tensor([0.229, 0.224, 0.225], device=device).view(3,1,1) + torch.tensor([0.485, 0.456, 0.406], device=device).view(3,1,1)
    return imgs


def get_transforms():
    normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    base_aug = [
        transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
        transforms.RandomGrayscale(), normalizer]
    global transform_train, transform_test
    transform_train = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor(), normalizer])
    return base_aug, normalizer


def ra_transforms():
    global transform_train
    transform_train.transforms.insert(1, transforms.RandAugment())


def base_transforms(aug='standard'):
    if aug == 'aa':
        transform_train.transforms.insert(1, transforms.AutoAugment())
    elif aug == 'ra':
        transform_train.transforms.insert(1, transforms.RandAugment())
    elif aug == 'cutout':
        normalize_index = next(
            (i for i, t in enumerate(transform_train.transforms) if isinstance(t, transforms.Normalize)), None)
        if normalize_index is not None:
            transform_train.transforms.insert(normalize_index, transforms.RandomErasing())
        else:
            transform_train.transforms.append(transforms.RandomErasing())


class PACS_SingleDomain():
    def __init__(self, root_path=pacs_path, domain_name='p', split='total', train_transform=None):
        if domain_name in pacs_name_dict.keys():
            self.domain_name = pacs_name_dict[domain_name]
            self.domain_label = list(pacs_name_dict.keys()).index(domain_name)
        else:
            raise ValueError('domain_name should be in p a c s')
        
        self.root_path = os.path.join(root_path, 'raw_images')
        self.split = split
        self.split_file = os.path.join(root_path, 'raw_images', 'Train val splits and h5py files pre-read', f'{self.domain_name}_{split_dict[self.split]}_kfold' + '.txt')
        
        if train_transform is not None:
            self.transform = train_transform
        else:
            self.transform = transform_test
                
        imgs, labels = PACS_SingleDomain.read_txt(self.split_file, self.root_path)
        self.dataset = MetaDataset(imgs, labels, self.domain_label, self.transform)
        
    @staticmethod
    def read_txt(txt_path, root_path):
        imgs = []
        labels = []
        with open(txt_path, 'r') as f:
            txt_component = f.readlines()
        for line_txt in txt_component:
            line_txt = line_txt.replace('\n', '')
            line_txt = line_txt.split(' ')
            imgs.append(os.path.join(root_path, line_txt[0]))
            labels.append(int(line_txt[1]) - 1)
        return imgs, labels
    
class PACS_FedDG():
    def __init__(self, test_domain='p', batch_size=16):
        self.batch_size = batch_size
        self.domain_list = list(pacs_name_dict.keys())
        self.test_domain = test_domain
        self.train_domain_list = self.domain_list.copy()
        self.train_domain_list.remove(self.test_domain)  
        
        self.site_dataset_dict = {}
        self.site_dataloader_dict = {}
        for domain_name in self.domain_list:
            self.site_dataloader_dict[domain_name], self.site_dataset_dict[domain_name] = PACS_FedDG.SingleSite(domain_name, self.batch_size)

        self.val_dataset = ConcatDataset([self.site_dataset_dict[i]['val'] for i in self.train_domain_list])
        self.val_dataloader = torch.utils.data.DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4, pin_memory=True)
        
        self.test_dataset = self.site_dataset_dict[self.test_domain]['test']
        self.test_dataloader = self.site_dataloader_dict[self.test_domain]['test']
        
          
    @staticmethod
    def SingleSite(domain_name, batch_size=16):
        dataset_dict = {
            'train': PACS_SingleDomain(domain_name=domain_name, split='train', train_transform=transform_train).dataset,
            'val': PACS_SingleDomain(domain_name=domain_name, split='val').dataset,
            'test': PACS_SingleDomain(domain_name=domain_name, split='total').dataset,
        }
        dataloader_dict = GetDataLoaderDict(dataset_dict, batch_size)
        return dataloader_dict, dataset_dict
        
    def GetData(self):
        return self.site_dataloader_dict, self.site_dataset_dict
    




