import torch.nn.functional as F
import random
import torch
import numpy as np

from torch.utils.data import Dataset
from collections import namedtuple
from torchmeta.datasets import Omniglot, MiniImagenet
from torchmeta.toy import Sinusoid
from torchmeta.transforms import ClassSplitter, Categorical, Rotation
from torchvision.transforms import ToTensor, Resize, Compose

from maml.model import ModelConvOmniglot, ModelConvMiniImagenet, ModelMLPSinusoid
from maml.utils import ToTensor1D
from torch.utils.data import DataLoader


# Immutable bundle returned by `get_benchmark_by_name`: the three meta-dataset
# splits, the model built for the benchmark, and the loss function to use.
Benchmark = namedtuple(
    'Benchmark',
    [
        'meta_train_dataset',
        'meta_val_dataset',
        'meta_test_dataset',
        'model',
        'loss_function',
    ],
)

def get_benchmark_by_name(dataset,
                          folder,
                          num_ways,
                          num_shots,
                          num_shots_test,
                          hidden_size=None,
                          seed=666):
    """Build the datasets, model and loss function for a named benchmark.

    Args:
        dataset: Benchmark name; one of 'sinusoid', 'omniglot' or
            'miniimagenet'.
        folder: Root folder handed to the Omniglot / MiniImagenet dataset
            constructors. Not used by the 'sinusoid' benchmark.
        num_ways: Number of classes per task; also the output size given to
            the convolutional models. Not used by the 'sinusoid' benchmark.
        num_shots: Number of training (support) examples per class.
        num_shots_test: Number of test (query) examples per class for the
            meta-train and meta-test splits. The 'sinusoid' and 'omniglot'
            validation splits use a fixed 15 instead.
        hidden_size: Forwarded to the convolutional models. The 'sinusoid'
            model always uses hidden_sizes=[40, 40].
        seed: Seed applied to every dataset through its `seed()` method.

    Returns:
        A `Benchmark` namedtuple with fields meta_train_dataset,
        meta_val_dataset, meta_test_dataset, model and loss_function.

    Raises:
        NotImplementedError: If `dataset` is not a known benchmark name.
    """

    if dataset == 'sinusoid':
        transform = ToTensor1D()

        dataset_transform = ClassSplitter(shuffle=True,
                                          num_train_per_class=num_shots,
                                          num_test_per_class=num_shots_test)

        # Validation tasks always use 15 query examples, independently of
        # num_shots_test.
        val_dataset_transform = ClassSplitter(shuffle=True,
                                              num_train_per_class=num_shots,
                                              num_test_per_class=15)

        meta_train_dataset = Sinusoid(num_shots + num_shots_test,
                                      num_tasks=1000000,
                                      transform=transform,
                                      target_transform=transform,
                                      dataset_transform=dataset_transform)
        meta_train_dataset.seed(seed)

        meta_val_dataset = Sinusoid(num_shots + 15,
                                    num_tasks=1000000,
                                    transform=transform,
                                    target_transform=transform,
                                    dataset_transform=val_dataset_transform)
        meta_val_dataset.seed(seed)

        meta_test_dataset = Sinusoid(num_shots + num_shots_test,
                                     num_tasks=1000000,
                                     transform=transform,
                                     target_transform=transform,
                                     dataset_transform=dataset_transform)
        meta_test_dataset.seed(seed)

        model = ModelMLPSinusoid(hidden_sizes=[40, 40])
        loss_function = F.mse_loss

    elif dataset == 'omniglot':
        dataset_transform = ClassSplitter(shuffle=True,
                                          num_train_per_class=num_shots,
                                          num_test_per_class=num_shots_test)
        # Validation tasks always use 15 query examples, independently of
        # num_shots_test.
        val_dataset_transform = ClassSplitter(shuffle=True,
                                              num_train_per_class=num_shots,
                                              num_test_per_class=15)

        class_augmentations = [Rotation([90, 180, 270])]
        transform = Compose([Resize(28), ToTensor()])

        meta_train_dataset = Omniglot(folder,
                                      transform=transform,
                                      target_transform=Categorical(num_ways),
                                      num_classes_per_task=num_ways,
                                      meta_train=True,
                                      class_augmentations=class_augmentations,
                                      dataset_transform=dataset_transform,
                                      download=True)
        meta_train_dataset.seed(seed)
        meta_val_dataset = Omniglot(folder,
                                    transform=transform,
                                    target_transform=Categorical(num_ways),
                                    num_classes_per_task=num_ways,
                                    meta_val=True,
                                    class_augmentations=class_augmentations,
                                    dataset_transform=val_dataset_transform)
        meta_val_dataset.seed(seed)
        # The meta-test split is built without class augmentations.
        meta_test_dataset = Omniglot(folder,
                                     transform=transform,
                                     target_transform=Categorical(num_ways),
                                     num_classes_per_task=num_ways,
                                     meta_test=True,
                                     dataset_transform=dataset_transform)
        meta_test_dataset.seed(seed)

        model = ModelConvOmniglot(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    elif dataset == 'miniimagenet':
        # The splitter must be created in this branch as well; without it
        # the name `dataset_transform` is unbound when the datasets below
        # are constructed.
        dataset_transform = ClassSplitter(shuffle=True,
                                          num_train_per_class=num_shots,
                                          num_test_per_class=num_shots_test)
        transform = Compose([Resize(84), ToTensor()])

        meta_train_dataset = MiniImagenet(folder,
                                          transform=transform,
                                          target_transform=Categorical(num_ways),
                                          num_classes_per_task=num_ways,
                                          meta_train=True,
                                          dataset_transform=dataset_transform,
                                          download=True)
        # Seed every split, as the other benchmarks do, so that `seed` is
        # honoured for this benchmark too.
        meta_train_dataset.seed(seed)
        meta_val_dataset = MiniImagenet(folder,
                                        transform=transform,
                                        target_transform=Categorical(num_ways),
                                        num_classes_per_task=num_ways,
                                        meta_val=True,
                                        dataset_transform=dataset_transform)
        meta_val_dataset.seed(seed)
        meta_test_dataset = MiniImagenet(folder,
                                         transform=transform,
                                         target_transform=Categorical(num_ways),
                                         num_classes_per_task=num_ways,
                                         meta_test=True,
                                         dataset_transform=dataset_transform)
        meta_test_dataset.seed(seed)

        model = ModelConvMiniImagenet(num_ways, hidden_size=hidden_size)
        loss_function = F.cross_entropy

    else:
        raise NotImplementedError('Unknown dataset `{0}`.'.format(dataset))

    return Benchmark(meta_train_dataset=meta_train_dataset,
                     meta_val_dataset=meta_val_dataset,
                     meta_test_dataset=meta_test_dataset,
                     model=model,
                     loss_function=loss_function)

def GetTaskPool(dataloader, num_tasks):
    """Collect the first `num_tasks` batches produced by `dataloader`.

    Args:
        dataloader: Any iterable of batches. Batches are stored as-is, so
            their structure is not inspected.
        num_tasks: Maximum number of batches to collect. A non-positive
            value yields an empty pool.

    Returns:
        A list with at most `num_tasks` batches, in iteration order. It is
        shorter when the iterable is exhausted first.
    """
    task_pool = []
    if num_tasks <= 0:
        return task_pool

    for batch in dataloader:
        task_pool.append(batch)
        # Stop right after the last wanted batch so no extra batch is
        # fetched from the loader just to be thrown away.
        if len(task_pool) >= num_tasks:
            break

    return task_pool

class TaskPoolDataset(Dataset):
    """Map-style dataset over a pre-collected pool of tasks.

    Each item is a new dict holding only the 'train' and 'test' entries of
    the task stored at the requested index.
    """

    def __init__(self, task_pool):
        self.task_pool = task_pool

    def __len__(self):
        return len(self.task_pool)

    def __getitem__(self, idx):
        selected = self.task_pool[idx]
        # Re-pack just the two splits; any other keys on the task are dropped.
        return {split: selected[split] for split in ('train', 'test')}

class MetaProxTaskPoolDataset(Dataset):
    """Dataset over a task pool that merges each task's two splits.

    Every task must expose 'train' and 'test' entries, each indexable as
    (inputs, labels). An item is the pair of tensors obtained by
    concatenating the train and test parts along dimension 1.
    """

    def __init__(self, task_pool):
        # The pool is a sequence of tasks; it is kept by reference.
        self.task_pool = task_pool

    def __len__(self):
        # One item per task in the pool.
        return len(self.task_pool)

    def __getitem__(self, idx):
        """Return `(data, labels)` for the task at `idx`, splits merged."""
        task = self.task_pool[idx]

        # Index 0 holds the inputs and index 1 the labels of each split.
        train_set = task['train']
        test_set = task['test']
        train_data, train_labels = train_set[0], train_set[1]
        test_data, test_labels = test_set[0], test_set[1]

        # Concatenate along dim 1.
        # NOTE(review): this presumably treats dim 0 as the task-batch axis
        # and dim 1 as the example axis -- confirm against the loader that
        # filled the pool.
        combined_data = torch.cat([train_data, test_data], dim=1)
        combined_labels = torch.cat([train_labels, test_labels], dim=1)

        return combined_data, combined_labels