import os

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import ConcatDataset, Subset
import torch.distributed as dist
from dataset.cglm import CGLM


def get_random_idx(l1, l2):
    """Draw a random subset of indices from ``range(l1)``.

    Keeps at most ``l2`` indices (all ``l1`` of them when ``l2 >= l1``) and
    returns them as an ascending int64 tensor without repeats.
    """
    keep = min(l1, l2)
    shuffled = torch.randperm(l1)
    chosen = shuffled[:keep]
    # randperm entries are already distinct; unique() is what sorts them.
    return chosen.unique().long()


class trans_label(object):
    """Target transform that shifts per-task labels into a shared label space.

    ImageFolder labels a task's classes 0..C-1. Adding the number of classes of
    all earlier tasks (counted from ``{args.data}/{i}/val/``) makes labels
    unique across tasks. NOTE(review): assumes train/ and val/ of a task hold
    the same class folders -- confirm.
    """

    def __init__(self, task, args):
        self.task = task
        self.args = args
        self.pre_classes = sum(self._count_classes(f'{args.data}/{i}/val/') for i in range(task))

    @staticmethod
    def _count_classes(folder):
        """Count sub-directories of ``folder`` (one per class, as ImageFolder does)."""
        # Plain files such as .DS_Store are not classes for ImageFolder and must
        # not shift the label offset.
        with os.scandir(folder) as entries:
            return sum(1 for entry in entries if entry.is_dir())

    def __call__(self, y):
        return y + self.pre_classes


def build_transform(is_train, args):
    """Return the torchvision preprocessing pipeline for training or evaluation.

    Train: RandomResizedCrop to ``args.input_size`` plus a random horizontal flip.
    Eval: resize the shorter side to ``input_size / crop_pct`` (bicubic), then
    center-crop to ``input_size``. ``crop_pct`` is 224/256 for inputs up to 224
    and 1.0 above that.
    Both end with ToTensor and ImageNet mean/std normalization.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if is_train:
        # Crop to the configured resolution so training matches the eval and
        # unlabeled pipelines (previously hard-coded to 224).
        return transforms.Compose([
            transforms.RandomResizedCrop(args.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    # Keep the 224/256 resize-to-crop ratio for small inputs; larger inputs are
    # resized straight to the crop size.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)
    return transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        normalize,
    ])


def get_val_set(args, task):
    """Build the evaluation dataset for ``task``.

    ImageNet10k / ImageNet2k: ImageFolder over ``{args.data}/{task}/val/`` with
    labels offset by ``trans_label``.
    cglm: CGLM over ``{args.data}/test.csv`` with the
    ``{args.data}/{args.split}time.txt`` time file.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    if args.dataset in ['ImageNet10k', 'ImageNet2k']:
        valdir = f'{args.data}/{task}/val/'
        return datasets.ImageFolder(valdir, build_transform(False, args), target_transform=trans_label(task, args))
    if args.dataset == 'cglm':
        valfile = f'{args.data}/test.csv'
        return CGLM(valfile, task, is_train=False, transform=build_transform(False, args),
                    timefile=f'{args.data}/{args.split}time.txt')
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')


def get_val_loader(args, task):
    """Return a non-shuffled DataLoader over the evaluation set of ``task``.

    With ``args.dist_eval`` the set is split across ranks by a
    DistributedSampler. Otherwise every process iterates the whole set.
    """
    val_set = get_val_set(args, task)
    # DistributedSampler shuffles by default; evaluation keeps the fixed order
    # that shuffle=False requests on the DataLoader below.
    val_sampler = (torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False)
                   if args.dist_eval else None)
    return torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False,
                                       num_workers=args.workers, pin_memory=False, sampler=val_sampler)


def get_unlabeled_set(args, task):
    """Return the unlabeled training data of ``task``.

    Images are augmented with RandomResizedCrop(args.input_size, scale 0.2-1.0,
    bicubic) and a horizontal flip, then normalized with ImageNet statistics.

    ImageNet10k / ImageNet2k: the subset of ``{args.data}/{task}/train`` whose
    indices are stored in ``{args.label_ratio}labeled/unlabeled_task{task}.buf``.
    cglm: CGLM over ``{args.data}/{task}_unlabeled.csv``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    transform = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0),
                                     interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    print('=> Loading data')
    if args.dataset in ['ImageNet10k', 'ImageNet2k']:
        dataset = datasets.ImageFolder(f'{args.data}/{task}/train', transform=transform)
        # Precomputed indices of the train images used as unlabeled data.
        unlabel_index = torch.load(f'{args.data}/{task}/{args.label_ratio}labeled/unlabeled_task{task}.buf')
        return Subset(dataset, unlabel_index)
    if args.dataset == 'cglm':
        return CGLM(f'{args.data}/{task}_unlabeled.csv', task, is_train=True, transform=transform)
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')

def get_cur_labeled_set(args, task):
    """Return the labeled training set of ``task`` with training augmentation.

    ImageNet10k / ImageNet2k: ImageFolder over ``{args.data}/{task}/train/``,
    restricted to the indices in ``{args.label_ratio}labeled/labeled_task{task}.buf``,
    with labels offset by ``trans_label``.
    cglm: CGLM over ``{args.data}/{task}_labeled.csv``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    if args.dataset in ['ImageNet10k', 'ImageNet2k']:
        cur_folder = f'{args.data}/{task}/train/'
        label_index = torch.load(f'{args.data}/{task}/{args.label_ratio}labeled/labeled_task{task}.buf')
        cur_full_set = datasets.ImageFolder(cur_folder, build_transform(True, args),
                                            target_transform=trans_label(task, args))
        return Subset(cur_full_set, label_index)
    if args.dataset == 'cglm':
        return CGLM(f'{args.data}/{task}_labeled.csv', task, is_train=True,
                    transform=build_transform(True, args))
    raise ValueError(f'Unsupported dataset: {args.dataset!r}')


def get_labeled_set(args, task, replay_first=None):
    """Return labeled training data for ``task`` and update the replay buffer.

    Args:
        replay_first: if true, the buffer is updated with ``task`` before the
            replay subsets are built. If false, replay uses the buffer stored
            after earlier tasks and is updated afterwards. Defaults to
            ``args.replay_first``.

    Returns:
        With ``args.cur_task_separate_men_set``: ``(cur_set, mem_set)``, where
        ``mem_set`` is a ConcatDataset of the replay subsets, or None when
        there is nothing to replay.
        Otherwise: ``(ConcatDataset(replay subsets + [cur_set]), None)``.
    """
    if replay_first is None:
        replay_first = args.replay_first
    cur_set = get_cur_labeled_set(args, task)
    if replay_first:
        # For task > 0 the returned buffer already has a row for this task,
        # so the replay subsets include samples of the current task.
        memory = update_memory(args, task, cur_set)
        mem_set = get_mem_set(args, task, memory=memory)
    else:
        mem_set = get_mem_set(args, task)
        update_memory(args, task, cur_set)
    if args.cur_task_separate_men_set:
        # ConcatDataset rejects an empty list, which happens at task 0 or with
        # a zero-size buffer; report "no replay set" as None instead.
        return cur_set, (ConcatDataset(mem_set) if mem_set else None)
    mem_set.append(cur_set)
    return ConcatDataset(mem_set), None


def get_mem_set(args, task, memory=None):
    """Turn the replay buffer into a list of per-task Subsets.

    Row ``t`` of the buffer holds indices into task ``t``'s labeled set. Each
    row is de-duplicated with ``unique()`` before building the Subset.

    Args:
        memory: buffer tensor to use. When None it is loaded from
            ``{args.output_dir}/buffer.buf``.

    Returns:
        An empty list when the buffer size is 0 or ``task`` is 0. Otherwise
        one Subset per buffer row.
    """
    nothing_to_replay = args.size_replay_buffer == 0 or task == 0
    if nothing_to_replay:
        return []
    buffer = torch.load(f'{args.output_dir}/buffer.buf') if memory is None else memory
    return [Subset(get_cur_labeled_set(args, t), rows.unique()) for t, rows in enumerate(buffer)]


def _pad_by_repeat(idx, size):
    """Return exactly ``size`` entries of ``idx``, cycling through it when it is too short."""
    if len(idx) >= size:
        return idx[:size]
    if len(idx) == 0:
        raise ValueError(f'cannot fill a replay row of size {size} from an empty index set')
    reps = -(-size // len(idx))  # ceiling division
    return idx.repeat(reps)[:size]


def update_memory(args, task, cur_set):
    """Resample the replay buffer after adding ``task`` and save it to disk.

    The buffer is a ``(task + 1, size_per_task)`` int64 tensor. Row ``i`` holds
    indices into task ``i``'s labeled set; the current task's row is drawn
    from ``range(len(cur_set))``.

    Row size:
    - Previous row size (capped by ``len(cur_set)``) while the total budget
      ``args.size_replay_buffer`` still has room.
    - Otherwise ``budget // (task + 1)``.

    Rows of earlier tasks are random subsets of their stored indices. A row
    with fewer distinct indices than ``size_per_task`` is filled by repeating
    its indices. Readers call ``.unique()`` on each row, so repeats are
    harmless; zero padding would silently add index 0.

    Rank 0's buffer is broadcast so all ranks agree. This requires an
    initialized process group and a CUDA device ``args.gpu``.

    Side effect: ``args.size_replay_buffer == -1`` (unbounded) is rewritten to 1e12.

    Returns:
        The buffer as a CPU tensor. It is also written to
        ``{args.output_dir}/buffer.buf``.
    """
    if args.size_replay_buffer == -1:
        args.size_replay_buffer = 1e12

    cur_size = len(cur_set)
    last_mem = None
    if task == 0:
        size_per_task = min(cur_size, int(args.size_replay_buffer))
    else:
        last_mem = torch.load(f'{args.output_dir}/buffer.buf')
        old_row = last_mem.shape[1]
        if args.size_replay_buffer - last_mem.nelement() >= min(cur_size, old_row):
            # Enough free budget: keep the previous per-task size.
            size_per_task = int(min(cur_size, old_row))
        else:
            # Budget exhausted: split it evenly over all tasks seen so far.
            size_per_task = int(args.size_replay_buffer // (task + 1))

    memory = torch.zeros([task + 1, size_per_task], dtype=torch.long)
    memory[task] = _pad_by_repeat(get_random_idx(cur_size, size_per_task), size_per_task)
    for i in range(task):
        last_idx = last_mem[i].unique()
        memory[i] = _pad_by_repeat(last_idx[get_random_idx(len(last_idx), size_per_task)], size_per_task)

    # Each rank drew its own random rows; adopt rank 0's copy everywhere.
    memory = memory.cuda(args.gpu)
    handle = dist.broadcast(memory, src=0, async_op=True)
    handle.wait()
    memory = memory.cpu()

    # One writer per device index 0 (presumably the local rank -- confirm);
    # the barrier keeps other ranks from reading buffer.buf before it is written.
    if args.gpu == 0:
        torch.save(memory, f'{args.output_dir}/buffer.buf')

    torch.distributed.barrier()
    # Synchronize this rank's own device; synchronize(0) would create a CUDA
    # context on GPU 0 from every rank.
    torch.cuda.synchronize(args.gpu)
    # NOTE: force= assumes builtins.print has been replaced by a
    # distributed-aware print; the stock print raises TypeError on it.
    print('gpu', args.gpu, memory, force=True)

    return memory
