import collections

import numpy as np
import robomimic.utils.tensor_utils as TensorUtils
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler

from libero.lifelong.algos.base import Sequential
from libero.lifelong.datasets import TruncatedSequenceDataset, TruncatedSequenceDatasetSample, MySequenceDataset
from libero.lifelong.utils import *


def cycle(dl):
    """Yield the items of ``dl`` forever, restarting it after every full pass.

    ``dl`` is iterated afresh on each pass instead of being cached, so an
    iterable that produces a different order per iteration keeps doing so.

    Note: an iterable that yields nothing makes this generator spin without
    ever producing a value.
    """
    while True:
        yield from dl


def merge_datas(x, y):
    """Recursively concatenate two identically structured batches.

    Args:
        x: a tensor, or a (possibly nested) dict / OrderedDict of tensors.
        y: a structure with the same keys as ``x``; every key of ``x`` is
            looked up in ``y``.

    Returns:
        A new structure shaped like ``x`` in which every tensor leaf is
        ``torch.cat([x_leaf, y_leaf], 0)``. The mapping type of ``x``
        (``dict`` or ``OrderedDict``) is preserved at every level. Leaves
        that are neither mappings nor tensors cannot be concatenated and are
        returned as ``None``.

    Raises:
        KeyError: if ``y`` lacks a key that is present in ``x``.
    """
    if isinstance(x, dict):
        # OrderedDict is a dict subclass, so it must be tested explicitly for
        # the container type to survive the merge.
        if isinstance(x, collections.OrderedDict):
            merged = collections.OrderedDict()
        else:
            merged = {}
        for key, value in x.items():
            merged[key] = merge_datas(value, y[key])
        return merged

    # Accept every tensor type (any dtype, any device), not only CPU
    # float32/int64, so that e.g. boolean masks are not silently dropped.
    if isinstance(x, torch.Tensor):
        return torch.cat([x, y], 0)

    # Unsupported leaf: keep the historical behaviour of yielding None.
    return None


class ER(Sequential):
    """
    The experience replay policy.

    After each task a truncated slice of that task's dataset is appended to a
    replay buffer. While training on later tasks, every incoming batch is
    merged with a randomly sampled batch from that buffer before the loss is
    computed.
    """

    def __init__(self, n_tasks, cfg, **policy_kwargs):
        """Initialise the base algorithm and an empty replay buffer."""
        super().__init__(n_tasks=n_tasks, cfg=cfg, **policy_kwargs)
        # We truncate each sequence dataset to a buffer; when replay is used,
        # all buffers are concatenated to form a single replay buffer.
        self.buffer_dataset = MySequenceDataset()
        self.descriptions = []
        # Endless iterator over replay batches; stays None until at least one
        # task has contributed data to ``buffer_dataset``.
        self.buffer = None

    def start_task(self, task):
        """Start ``task`` and, if replay data exists, (re)build the replay iterator."""
        super().start_task(task)
        if len(self.buffer_dataset) > 0:
            num_workers = self.cfg.train.num_workers
            self.buffer = cycle(
                DataLoader(
                    self.buffer_dataset,
                    batch_size=self.cfg.train.batch_size,
                    num_workers=num_workers,
                    sampler=RandomSampler(self.buffer_dataset),
                    # DataLoader raises ValueError for persistent_workers=True
                    # when no worker processes are used.
                    persistent_workers=num_workers > 0,
                )
            )

    def end_task(self, dataset, task_id, benchmark):
        """Append the first ``cfg.lifelong.n_demos`` entries of ``dataset`` to the replay buffer.

        ``task_id`` and ``benchmark`` are accepted but not used here.
        """
        # NOTE(review): relies on the dataset types supporting slicing and
        # ``+``; confirm against the dataset implementations.
        self.buffer_dataset = self.buffer_dataset + dataset[:self.cfg.lifelong.n_demos]

    def observe(self, data):
        """Run one optimisation step on ``data`` (plus a replay batch, if any).

        Args:
            data: the current-task batch; when a replay iterator exists it is
                merged with the next replay batch via ``merge_datas``.

        Returns:
            The unscaled loss of this step as a Python number (``loss.item()``).
        """
        if self.buffer is not None:
            buf_data = next(self.buffer)
            data = merge_datas(data, buf_data)

        data = self.map_tensor_to_device(data)

        self.optimizer.zero_grad()
        loss = self.policy.compute_loss(data)
        (self.loss_scale * loss).backward()

        # NOTE(review): gradients are clipped per process *before* they are
        # averaged across processes; confirm this ordering is intended.
        if self.cfg.train.grad_clip is not None:
            nn.utils.clip_grad_norm_(
                self.policy.parameters(), self.cfg.train.grad_clip
            )

        self.average_gradients()

        self.optimizer.step()
        return loss.item()

    def average_gradients(self):
        """Average the policy gradients across all distributed processes.

        Does nothing when ``torch.distributed`` is unavailable, when no
        process group has been initialised, or when there is a single process,
        so the algorithm also runs outside a distributed launch.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return

        world_size = dist.get_world_size()
        if world_size <= 1:
            return

        size = float(world_size)
        for param in self.policy.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data /= size