import numpy as np
from torch.utils.data import DataLoader

from oucl.scenarios.samplers import load_sampler
from oucl.scenarios.collate_fns import load_collate_fn

from oucl.scenarios.evaluators import load_evaluator

from tqdm import tqdm


def get_dynamic_inds(dataset, scenario_conf):
    """Build the sample indices and schedule for a 'dynamic' scenario.

    Every class listed in ``scenario_conf.task_classes`` has its sample
    indices (positions where ``dataset.labels == y``) split, in dataset
    order, into a stream portion, a validation portion and an evaluation
    portion. Supervision indices are drawn at random, without replacement,
    from the stream portion.

    Args:
        dataset: Object exposing ``labels``; assumed to be a NumPy array so
            that ``dataset.labels == y`` is element-wise -- TODO confirm.
        scenario_conf: Configuration object providing ``task_classes``,
            ``eval_task_classes``, ``task_sizes`` (per-task multipliers of
            the task's stream pool size), ``stream``, ``super``, ``eval``,
            ``val`` (per-class ratios), ``batch_size`` and ``eval_freq``.
            It is only read, never modified.

    Returns:
        A tuple ``(stream_inds, super_tasks, eval_tasks, val_tasks,
        task_boundaries, eval_iters)`` where ``stream_inds`` is a NumPy
        array of dataset indices in stream order, the three ``*_tasks``
        values hold one index list per entry of ``eval_task_classes``,
        ``task_boundaries`` holds the cumulative number of full batches at
        the end of each task, and ``eval_iters`` holds the batch indices
        selected for evaluation (empty when ``eval_freq <= 0``).
    """
    task_classes = scenario_conf.task_classes
    eval_task_classes = scenario_conf.eval_task_classes
    # Work on a copy: the configured multipliers must survive this call so
    # that the function can be invoked again with the same configuration.
    task_fractions = list(scenario_conf.task_sizes)
    batch_size = scenario_conf.batch_size
    eval_freq = scenario_conf.eval_freq

    stream_ratio = scenario_conf.stream
    super_ratio = scenario_conf.super
    eval_ratio = scenario_conf.eval
    val_ratio = scenario_conf.val

    stream_by_class = {}
    super_by_class = {}
    val_by_class = {}
    eval_by_class = {}

    for y in np.concatenate(task_classes):
        class_inds = np.argwhere(dataset.labels == y).flatten()
        size = len(class_inds)
        n_stream = int(size * stream_ratio)
        n_super = int(size * super_ratio)
        n_val = int(size * val_ratio)
        n_eval = int(size * eval_ratio)

        stream_by_class[y] = class_inds[:n_stream].tolist()
        # Supervision samples are a random subset of the stream portion.
        picked = np.random.choice(n_stream, n_super, replace=False).flatten()
        super_by_class[y] = class_inds[picked].tolist()
        val_end = n_stream + n_val
        val_by_class[y] = class_inds[n_stream:val_end].tolist()
        eval_by_class[y] = class_inds[val_end:val_end + n_eval].tolist()

    # Stream indices, task after task, plus the resolved sample count of
    # each task.
    stream_inds = []
    task_sizes = []
    for t, task in enumerate(task_classes):
        task_inds = [i for y in task for i in stream_by_class[y]]
        fraction = task_fractions[t]
        n_samples = int(fraction * len(task_inds))
        # A multiplier above 1 asks for more samples than the pool holds,
        # which is only possible when sampling with replacement.
        drawn = np.random.choice(task_inds, n_samples,
                                 replace=bool(fraction > 1))
        stream_inds += drawn.flatten().tolist()
        task_sizes.append(n_samples)

    # Supervision / evaluation / validation index sets per evaluation task.
    super_tasks = []
    eval_tasks = []
    val_tasks = []
    for task in eval_task_classes:
        super_tasks.append([i for y in task for i in super_by_class[y]])
        eval_tasks.append([i for y in task for i in eval_by_class[y]])
        val_tasks.append([i for y in task for i in val_by_class[y]])

    # Schedule, expressed in batches. Integer division is used on purpose:
    # multiplying by the float 1 / batch_size can round an exact multiple
    # down by one batch.
    task_boundaries = []
    eval_iters = []
    batches_so_far = 0
    for n_samples in task_sizes:
        n_batches = n_samples // batch_size
        if eval_freq > 0:
            # eval_freq evenly spaced evaluation points inside the task.
            # NOTE: when n_batches < eval_freq, part is 0 and every entry
            # collapses to batches_so_far - 1.
            part = n_batches // eval_freq
            eval_iters += [batches_so_far + (i + 1) * part - 1
                           for i in range(eval_freq)]
        batches_so_far += n_batches
        task_boundaries.append(batches_so_far)

    return (np.array(stream_inds), super_tasks, eval_tasks, val_tasks,
            task_boundaries, eval_iters)

def get_mixed_inds(dataset, scenario_conf):
    """Build the sample indices and schedule for a 'mixed' scenario.

    Every class listed in ``scenario_conf.task_classes`` has its sample
    indices (positions where ``dataset.labels == y``) split, in dataset
    order, into a stream portion, a validation portion and an evaluation
    portion. Supervision indices are drawn at random, without replacement,
    from the stream portion.

    The stream of each task is assembled from fixed-length segments. Each
    segment draws ``bias_degree`` distinct classes of the task and then
    ``segment_length`` samples from the stream portions of those classes.

    Args:
        dataset: Object exposing ``labels``; assumed to be a NumPy array so
            that ``dataset.labels == y`` is element-wise -- TODO confirm.
        scenario_conf: Configuration object providing ``task_classes``,
            ``eval_task_classes``, ``task_epochs`` (one value per task, or
            a single number applied to every task), ``segment_length``,
            ``bias_degree``, ``stream``, ``super``, ``eval``, ``val``
            (per-class ratios), ``batch_size`` and ``eval_freq``.

    Returns:
        A tuple ``(stream_inds, super_tasks, eval_tasks, val_tasks,
        task_boundaries, eval_iters)`` where ``stream_inds`` is a NumPy
        array of dataset indices in stream order, the three ``*_tasks``
        values hold one index list per entry of ``eval_task_classes``,
        ``task_boundaries`` holds the cumulative number of full batches at
        the end of each task, and ``eval_iters`` holds the batch indices
        selected for evaluation.

    Raises:
        ValueError: If the classes drawn for a segment have no stream
            samples, so the segment could never be filled.
    """
    task_classes = scenario_conf.task_classes
    eval_task_classes = scenario_conf.eval_task_classes
    task_epochs = scenario_conf.task_epochs
    # A single number is broadcast to every task.
    if isinstance(task_epochs, (int, float)):
        task_epochs = [task_epochs] * len(task_classes)

    seg_len = scenario_conf.segment_length
    bias_degree = scenario_conf.bias_degree
    batch_size = scenario_conf.batch_size
    eval_freq = scenario_conf.eval_freq

    stream_ratio = scenario_conf.stream
    super_ratio = scenario_conf.super
    eval_ratio = scenario_conf.eval
    val_ratio = scenario_conf.val

    stream_by_class = {}
    super_by_class = {}
    val_by_class = {}
    eval_by_class = {}
    # Keyed by class label (not by position) so that task class lists may
    # contain arbitrary labels in arbitrary order.
    stream_size_by_class = {}

    for y in np.concatenate(task_classes):
        class_inds = np.argwhere(dataset.labels == y).flatten()
        size = len(class_inds)
        n_stream = int(size * stream_ratio)
        n_super = int(size * super_ratio)
        n_val = int(size * val_ratio)
        n_eval = int(size * eval_ratio)

        stream_by_class[y] = class_inds[:n_stream].tolist()
        # Supervision samples are a random subset of the stream portion.
        picked = np.random.choice(n_stream, n_super, replace=False).flatten()
        super_by_class[y] = class_inds[picked].tolist()
        val_end = n_stream + n_val
        val_by_class[y] = class_inds[n_stream:val_end].tolist()
        eval_by_class[y] = class_inds[val_end:val_end + n_eval].tolist()
        stream_size_by_class[y] = n_stream

    stream_inds = []
    # Cumulative number of full batches at the end of each task.
    task_boundaries = []
    size_so_far = 0

    for t, task in enumerate(task_classes):
        task_size = sum(stream_size_by_class[y] for y in task)
        total_segments = int(task_size * task_epochs[t] // seg_len)
        for _ in tqdm(range(total_segments)):
            # Exactly bias_degree distinct classes feed each segment.
            seg_classes = np.random.choice(task, bias_degree, replace=False)
            pool = [i for y in seg_classes for i in stream_by_class[y]]

            if len(pool) >= seg_len:
                seg_inds = np.random.choice(
                    pool, seg_len, replace=False).flatten().tolist()
            else:
                if not pool:
                    # Without this guard the refill loop below never ends.
                    raise ValueError(
                        'Cannot build a segment: the selected classes '
                        'have no stream samples.')
                # Pool smaller than a segment: append reshuffled copies of
                # the pool until the segment is full, then trim.
                seg_inds = []
                while len(seg_inds) < seg_len:
                    np.random.shuffle(pool)
                    seg_inds += pool
                seg_inds = seg_inds[:seg_len]

            stream_inds += seg_inds
            size_so_far += len(seg_inds)

        task_boundaries.append(size_so_far // batch_size)

    # Supervision / evaluation / validation index sets per evaluation task.
    super_tasks = []
    eval_tasks = []
    val_tasks = []
    for task in eval_task_classes:
        super_tasks.append([i for y in task for i in super_by_class[y]])
        eval_tasks.append([i for y in task for i in eval_by_class[y]])
        val_tasks.append([i for y in task for i in val_by_class[y]])

    if eval_freq > 1:
        # eval_freq evenly spaced evaluation points inside every task.
        eval_iters = []
        task_starts = [0] + task_boundaries[:-1]
        for start, end in zip(task_starts, task_boundaries):
            part = (end - start) // eval_freq
            eval_iters += [start + (j + 1) * part - 1
                           for j in range(eval_freq)]
    else:
        # Any eval_freq <= 1 evaluates once, on the last batch of each task.
        eval_iters = [end - 1 for end in task_boundaries]

    return (np.array(stream_inds), super_tasks, eval_tasks, val_tasks,
            task_boundaries, eval_iters)


def load_scenario(dataset, config):
    """Create the training stream and the evaluator for a scenario.

    The index builder is selected by ``config.scenario.type``
    (``'dynamic'`` or ``'mixed'``). Its stream indices are handed to
    ``load_sampler``, and the resulting sampler drives a ``DataLoader``
    over ``dataset``. The evaluation schedule is attached to the loader as
    the attributes ``eval_iters`` and ``task_boundaries``.

    Args:
        dataset: Dataset passed to the index builder, the ``DataLoader``
            and ``load_evaluator``.
        config: Configuration object with ``scenario``, ``agent`` and
            ``dataset`` sections.

    Returns:
        A tuple ``(stream, evaluator)``.

    Raises:
        ValueError: If ``config.scenario.type`` is not a supported type.
    """
    scenario_type = config.scenario.type
    if scenario_type == 'dynamic':
        build_inds = get_dynamic_inds
    elif scenario_type == 'mixed':
        build_inds = get_mixed_inds
    else:
        # Fail early with a readable message instead of an unbound-variable
        # error further down.
        raise ValueError(
            f"Unknown scenario type {scenario_type!r}; "
            "expected 'dynamic' or 'mixed'.")

    (stream_inds, super_inds, eval_inds, val_inds,
     task_boundaries, eval_iters) = build_inds(dataset, config.scenario)

    collate_fn = load_collate_fn(config.agent.collate,
                                 config.agent.stream_transform,
                                 config.dataset.img_size,
                                 config.dataset.mean,
                                 config.dataset.std,
                                 config.agent)

    sampler = load_sampler(stream_inds, config)

    num_workers = config.scenario.num_workers
    stream = DataLoader(dataset,
                        batch_size=config.scenario.batch_size,
                        sampler=sampler,
                        collate_fn=collate_fn,
                        shuffle=False,
                        num_workers=num_workers,
                        pin_memory=True,
                        # DataLoader rejects persistent workers when loading
                        # happens in the main process (num_workers == 0).
                        persistent_workers=num_workers > 0,
                        drop_last=True)

    # Expose the schedule on the loader so consumers can read it from there.
    stream.eval_iters = eval_iters
    stream.task_boundaries = task_boundaries

    evaluator = load_evaluator(dataset, super_inds, eval_inds, val_inds,
                               config)

    return stream, evaluator



        