import os
import io
from glob import glob
from collections import OrderedDict, defaultdict
import math

import numpy as np
import torch
import torch.distributed as dist
import torchvision.utils as vutils
import torchvision.transforms.functional as TF
import PIL.Image
from mpi4py import MPI


# Note! This is the squared L2 distance, not the L2 distance itself
def l2(a, b):
    """Return the squared L2 distance between ``a`` and ``b``.

    The element-wise absolute difference is squared and summed over
    dimension 1; no square root is taken.
    """
    magnitude = torch.abs(a - b)
    squared = torch.pow(magnitude, 2)
    return squared.sum(dim=1)


# required when we load optimizer from a checkpoint
def optimizer_cuda(optimizer, device):
    """Move every tensor stored in the optimizer state to ``device``.

    The per-parameter state dicts are updated in place; entries that are not
    tensors are left untouched.
    """
    for param_state in optimizer.state.values():
        for name, entry in param_state.items():
            if not isinstance(entry, torch.Tensor):
                continue
            param_state[name] = entry.to(device)


def get_ckpt_path(base_dir, ckpt_num):
    """Locate a checkpoint file inside ``base_dir``.

    Args:
        base_dir: Directory that is searched for ``*.pt`` files.
        ckpt_num: Step number of the wanted checkpoint, or None to fall back
            to ``get_recent_ckpt_path``.

    Returns:
        A ``(path, ckpt_num)`` tuple for the first file whose path contains
        ``ckpt_<8-digit zero-padded step>.pt``.

    Raises:
        Exception: If no file matches the requested step.
    """
    if ckpt_num is None:
        return get_recent_ckpt_path(base_dir)

    candidates = glob(os.path.join(base_dir, "*.pt"))
    match = next(
        (path for path in candidates if "ckpt_%08d.pt" % ckpt_num in path),
        None,
    )
    if match is None:
        raise Exception("Did not find ckpt_%s.pt" % ckpt_num)
    return match, ckpt_num


def get_recent_ckpt_path(base_dir):
    """Return the checkpoint in ``base_dir`` with the highest step number.

    The step is read from the file name: the text after the last underscore
    and before the first following dot (e.g. ``ckpt_00000100.pt`` -> 100).
    Steps are compared numerically, so names without zero padding are ordered
    correctly. ``*.pt`` files whose step part is not an integer are ignored.

    Args:
        base_dir: Directory that is searched for ``*.pt`` files.

    Returns:
        A ``(path, step)`` tuple, or ``(None, None)`` when no checkpoint with
        a numeric step is found.

    Raises:
        Exception: If several files share the highest step number.
    """
    steps = {}
    for path in sorted(glob(os.path.join(base_dir, "*.pt"))):
        # Parse only the base name so that underscores or dots in directory
        # names cannot leak into the step string.
        suffix = os.path.basename(path).rsplit("_", 1)[-1].split(".")[0]
        try:
            steps[path] = int(suffix)
        except ValueError:
            continue
    if not steps:
        return None, None

    max_step = max(steps.values())
    paths = [path for path, step in steps.items() if step == max_step]
    if len(paths) == 1:
        return paths[0], max_step
    raise Exception("Multiple most recent ckpts %s" % paths)


def image_grid(image, n=4):
    """Arrange the first ``n`` images of a batch in a single-row grid.

    The grid is built with ``torchvision.utils.make_grid`` (``nrow=n``) and
    returned as a NumPy array.
    """
    grid = vutils.make_grid(image[:n], nrow=n)
    return grid.cpu().detach().numpy()


def count_parameters(model):
    """Return the number of elements in all parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def slice_tensor(input, indices):
    """Index every value of a dict with the same ``indices``.

    Returns a new plain dict with the same keys, where each value is
    ``value[indices]``.
    """
    return {key: value[indices] for key, value in input.items()}


def average_gradients(model):
    """Average the gradients of ``model`` over all ``torch.distributed`` ranks.

    Each existing gradient is sum-reduced in place across the process group
    and then divided by the world size. Parameters without a gradient are
    skipped.
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            continue
        grad = param.grad.data
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= world_size


def ensure_shared_grads(model, shared_model):
    """Hand the gradients of ``model`` over to ``shared_model`` (for A3C).

    Parameters are paired in iteration order. As soon as a shared parameter
    that already holds a gradient is met, the function returns and leaves the
    remaining parameters untouched.
    """
    pairs = zip(model.parameters(), shared_model.parameters())
    for local_param, shared_param in pairs:
        if shared_param.grad is not None:
            return
        shared_param._grad = local_param.grad


def compute_gradient_norm(model):
    """Return the sum of squared gradient entries over all parameters.

    No square root is taken, so the result is the squared L2 norm of the
    gradients. Parameters without a gradient are skipped; the result is the
    integer 0 when no parameter has a gradient.
    """
    total = 0
    for param in model.parameters():
        if param.grad is None:
            continue
        total += param.grad.data.pow(2).sum().item()
    return total


def compute_weight_norm(model):
    """Return the sum of squared entries over all parameters of ``model``.

    No square root is taken, so the result is the squared L2 norm of the
    weights.
    """
    total = 0
    for param in model.parameters():
        weights = param.data
        if weights is None:
            continue
        total += weights.pow(2).sum().item()
    return total


def compute_weight_sum(model):
    """Return the sum of absolute parameter values (L1 norm) of ``model``."""
    total = 0
    for param in model.parameters():
        weights = param.data
        if weights is None:
            continue
        total += weights.abs().sum().item()
    return total


# sync_networks across the different cores
def sync_networks(network):
    """
    Broadcast the parameters held by MPI rank 0 to every other rank.

    network is the network you want to sync. Nothing is done when only a
    single MPI process is running.
    """
    comm = MPI.COMM_WORLD
    if comm.Get_size() != 1:
        flat_params, params_shape = _get_flat_params(network)
        # the buffer is sent from root 0 and received into on the other ranks
        comm.Bcast(flat_params, root=0)
        # write the (possibly received) flat params back into the network
        _set_flat_params(network, params_shape, flat_params)


# get the flat params from the network
def _get_flat_params(network):
    param_shape = {}
    flat_params = None
    for key_name, value in network.named_parameters():
        param_shape[key_name] = value.cpu().detach().numpy().shape
        if flat_params is None:
            flat_params = value.cpu().detach().numpy().flatten()
        else:
            flat_params = np.append(flat_params, value.cpu().detach().numpy().flatten())
    return flat_params, param_shape


# set the params from the network
def _set_flat_params(network, params_shape, params):
    pointer = 0
    if hasattr(network, "_config"):
        device = network._config.device
    else:
        device = torch.device("cpu")

    for key_name, values in network.named_parameters():
        # get the length of the parameters
        len_param = np.prod(params_shape[key_name])
        copy_params = params[pointer : pointer + len_param].reshape(
            params_shape[key_name]
        )
        copy_params = torch.tensor(copy_params).to(device)
        # copy the params
        values.data.copy_(copy_params.data)
        # update the pointer
        pointer += len_param

def weights_init(m):
    """Re-initialise the direct children of ``m``.

    ``reset_parameters()`` is called on every immediate child module that
    defines it; ``m`` itself and deeper descendants are not touched.
    """
    resettable = (
        child for child in m.children() if hasattr(child, 'reset_parameters')
    )
    for child in resettable:
        child.reset_parameters()


# sync gradients across the different cores
def sync_grads(network):
    """Sum the gradients of ``network`` across all MPI processes.

    The flattened local gradients are combined with an ``MPI.SUM`` all-reduce
    and written back into the network. The result is a sum, not an average.
    Nothing is done when only a single MPI process is running.
    """
    comm = MPI.COMM_WORLD
    if comm.Get_size() != 1:
        local_grads, grads_shape = _get_flat_grads(network)
        summed_grads = np.zeros_like(local_grads)
        comm.Allreduce(local_grads, summed_grads, op=MPI.SUM)
        _set_flat_grads(network, grads_shape, summed_grads)


def _set_flat_grads(network, grads_shape, flat_grads):
    pointer = 0
    if hasattr(network, "_config"):
        device = network._config.device
    else:
        device = torch.device("cpu")

    for key_name, value in network.named_parameters():
        if key_name in grads_shape:
            len_grads = np.prod(grads_shape[key_name])
            copy_grads = flat_grads[pointer : pointer + len_grads].reshape(
                grads_shape[key_name]
            )
            copy_grads = torch.tensor(copy_grads).to(device)
            # copy the grads
            value.grad.data.copy_(copy_grads.data)
            pointer += len_grads


def _get_flat_grads(network):
    grads_shape = {}
    flat_grads = None
    for key_name, value in network.named_parameters():
        try:
            grads_shape[key_name] = value.grad.data.cpu().numpy().shape
        except:
            print("Cannot get grad of tensor {}".format(key_name))
            continue

        if flat_grads is None:
            flat_grads = value.grad.data.cpu().numpy().flatten()
        else:
            flat_grads = np.append(flat_grads, value.grad.data.cpu().numpy().flatten())
    return flat_grads, grads_shape


def fig2tensor(draw_func):
    """Decorator that converts the figure returned by ``draw_func`` to a tensor.

    The wrapped function is called with the original arguments; the object it
    returns must provide ``savefig`` and ``clf`` (presumably a matplotlib
    figure). The figure is rendered into an in-memory buffer at 88 dpi,
    cleared, and the rendered image is returned via
    ``torchvision.transforms.functional.to_tensor``.
    """
    import functools

    # wraps keeps the name and docstring of the decorated drawing function
    @functools.wraps(draw_func)
    def decorate(*args, **kwargs):
        tmp = io.BytesIO()
        fig = draw_func(*args, **kwargs)
        fig.savefig(tmp, dpi=88)
        # rewind so the image reader starts at the beginning of the buffer
        tmp.seek(0)
        fig.clf()
        return TF.to_tensor(PIL.Image.open(tmp))

    return decorate


def tensor2np(t):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Tensors are cloned, detached and moved to the CPU before conversion, so
    the returned array does not share memory with ``t``.
    """
    if not isinstance(t, torch.Tensor):
        return t
    return t.clone().detach().cpu().numpy()


def tensor2img(tensor):
    """Write an image tensor to ``tensor.png`` in the working directory.

    A 4-D input must have a leading dimension of size 1, which is dropped.
    The remaining three dimensions are reordered from (0, 1, 2) to (1, 2, 0)
    -- presumably (C, H, W) to (H, W, C) -- before the array is handed to
    ``cv2.imwrite``.
    """
    is_batched = len(tensor.shape) == 4
    if is_batched:
        assert tensor.shape[0] == 1
        tensor = tensor.squeeze(0)
    channels_last = tensor.permute(1, 2, 0)
    img = channels_last.detach().cpu().numpy()
    import cv2

    cv2.imwrite("tensor.png", img)


def obs2tensor(obs, device):
    """Convert observations to an OrderedDict of float32 tensors on ``device``.

    Args:
        obs: Either a dict mapping keys to sequences of arrays, or a list of
            such per-step dicts (converted with ``list2dict`` first).
        device: Target device for the resulting tensors.

    Returns:
        OrderedDict mapping each key to the stacked values as a float32
        tensor.
    """
    if isinstance(obs, list):
        obs = list2dict(obs)

    tensors = OrderedDict()
    for key, values in obs.items():
        stacked = np.stack(values)
        tensors[key] = torch.tensor(stacked, dtype=torch.float32).to(device)
    return tensors

def mse_dict_tensor(dict1, dict2):
    """Return the mean squared error between two (dicts of) tensors.

    A dict argument is turned into a single tensor by concatenating its
    values along the last dimension; if the first value is 1-D, every value
    first gets a leading dimension of size 1. Any other argument is used as
    is, so a dict can be compared against an already concatenated tensor.

    Args:
        dict1: Dict of tensors, or a tensor.
        dict2: Dict of tensors, or a tensor.

    Returns:
        Scalar tensor holding the mean of the squared differences.
    """

    def _as_tensor(value):
        # Concatenate the values of a dict; pass plain tensors through.
        if not isinstance(value, dict):
            return value
        parts = list(value.values())
        if len(parts[0].shape) == 1:
            parts = [x.unsqueeze(0) for x in parts]
        return torch.cat(parts, dim=-1)

    diff = _as_tensor(dict1) - _as_tensor(dict2)
    return diff.pow(2).mean()



# transfer a numpy array into a tensor
def to_tensor(x, device):
    """Convert array-like data to float tensors on ``device``.

    Args:
        x: A dict of array-likes, a list of dicts, a list of array-likes, or
            a single array-like. For lists, only the first element is
            inspected to decide whether the elements are dicts.
        device: Target device for the resulting tensors.

    Returns:
        An OrderedDict of tensors for a dict, a list of OrderedDicts or
        tensors for a list (an empty list for an empty list), and a single
        tensor otherwise.
    """

    def _convert(value):
        return torch.as_tensor(value, device=device).float()

    def _convert_dict(mapping):
        return OrderedDict([(k, _convert(v)) for k, v in mapping.items()])

    if isinstance(x, dict):
        return _convert_dict(x)
    if isinstance(x, list):
        # `x and ...` guards the x[0] lookup so an empty list yields []
        if x and isinstance(x[0], dict):
            return [_convert_dict(item) for item in x]
        return [_convert(item) for item in x]
    return _convert(x)


def list2dict(rollout):
    """Turn a list of dicts into an OrderedDict of lists.

    The keys are taken from the first element of ``rollout``; the values of
    every element are then appended, in order, to the list of their key.
    """
    collected = OrderedDict((key, []) for key in rollout[0].keys())
    for transition in rollout:
        for key, value in transition.items():
            collected[key].append(value)
    return collected


def scale_dict_tensor(tensor, scalar):
    """Multiply every leaf of a nested dict/list structure by ``scalar``.

    Dicts are returned as OrderedDicts, lists as lists; any other value is
    returned as ``value * scalar``.
    """
    if isinstance(tensor, dict):
        scaled = OrderedDict()
        for key in tensor.keys():
            scaled[key] = scale_dict_tensor(tensor[key], scalar)
        return scaled
    if isinstance(tensor, list):
        return [scale_dict_tensor(item, scalar) for item in tensor]
    return tensor * scalar


# From softlearning repo
def flatten(unflattened, parent_key="", separator="/"):
    """Flatten a nested mapping into a single-level OrderedDict.

    Nested keys are joined with ``separator``. Empty nested mappings are kept
    as leaf values.

    Args:
        unflattened: Mapping whose values may themselves be mappings.
        parent_key: Prefix prepended to every key (used by the recursion).
        separator: String placed between the parts of a joined key.

    Returns:
        OrderedDict mapping joined keys to leaf values.

    Raises:
        ValueError: If a key already contains ``separator``.
    """
    # The ABC lives in collections.abc; the old collections.MutableMapping
    # alias was removed in Python 3.10.
    from collections.abc import MutableMapping

    items = []
    for k, v in unflattened.items():
        if separator in k:
            raise ValueError("Found separator ({}) from key ({})".format(separator, k))
        new_key = parent_key + separator + k if parent_key else k
        if isinstance(v, MutableMapping) and v:
            items.extend(flatten(v, new_key, separator=separator).items())
        else:
            items.append((new_key, v))

    return OrderedDict(items)

# dict-list to list-dict
def process_transition(idx, new_transitions):
    """Split batched transitions into one single-step dict per index.

    For every index in ``idx`` a dict with the keys of ``new_transitions`` is
    built. Each value is a one-element list holding either the indexed entry,
    or, for dict values, a dict of the indexed entries of its items.
    """

    def _pick(value, i):
        # Select entry i and wrap it in a one-element list.
        if isinstance(value, dict):
            return [{key: sub[i] for key, sub in value.items()}]
        return [value[i]]

    return [
        {k: _pick(v, i) for k, v in new_transitions.items()}
        for i in idx
    ]

# list-dict to dict-list

# From softlearning repo
def unflatten(flattened, separator="."):
    """Rebuild a nested dict from a flat mapping with joined keys.

    Every key is split on ``separator``; all parts but the last name nested
    dicts (created on demand) and the last part receives the value. Note that
    the default separator here is ``"."``.
    """
    nested = {}
    for key, value in flattened.items():
        *parents, leaf = key.split(separator)
        node = nested
        for name in parents:
            if name not in node:
                node[name] = {}
            node = node[name]
        node[leaf] = value

    return nested


# from https://github.com/MishaLaskin/rad/blob/master/utils.py
def center_crop(img, out=84):
    """
        Cut a centred square window out of a single image.

        args:
        img: np.array shape (C,H,W)
        out: output size (e.g. 84)
        returns np.array shape (1,C,out,out)
    """
    height, width = img.shape[1:]
    row0 = (height - out) // 2
    col0 = (width - out) // 2

    window = img[:, row0 : row0 + out, col0 : col0 + out]
    return np.expand_dims(window, axis=0)

# from https://github.com/MishaLaskin/rad/blob/master/utils.py
def center_crop_images(image, out=84):
    """
        Cut a centred square window out of every image of a batch.

        args:
        image: array shape (B,C,H,W)
        out: output size (e.g. 84)
        returns array shape (B,C,out,out)
    """
    height, width = image.shape[2:]
    row0 = (height - out) // 2
    col0 = (width - out) // 2
    rows = slice(row0, row0 + out)
    cols = slice(col0, col0 + out)
    return image[:, :, rows, cols]


# from https://github.com/MishaLaskin/rad/blob/master/data_augs.py
def random_crop(imgs, out=84):
    """
        Cut a randomly placed square window out of every image of a batch.

        The offsets are drawn with np.random, independently per image and
        separately bounded for height and width, so non-square inputs are
        supported.

        args:
        imgs: np.array shape (B,C,H,W)
        out: output size (e.g. 84)
        returns np.array shape (B,C,out,out)
    """
    b, c, h, w = imgs.shape
    # Bound each offset by its own axis; using the height for both breaks
    # images whose width differs from their height.
    w1 = np.random.randint(0, w - out + 1, b)
    h1 = np.random.randint(0, h - out + 1, b)
    cropped = np.empty((b, c, out, out), dtype=imgs.dtype)
    for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)):
        cropped[i] = img[:, h11 : h11 + out, w11 : w11 + out]
    return cropped


def sample_from_dataloader(dataloader, num_samples, batch_size):
    """Draw ``num_samples`` samples from ``dataloader``.

    ``ceil(num_samples / batch_size)`` consecutive batches are taken from a
    single iterator over the loader, which is restarted when it runs out.

    Args:
        dataloader: Iterable yielding dict batches. Values are array-likes or
            dicts of array-likes.
        num_samples: Number of samples to return.
        batch_size: Number of samples per batch of ``dataloader``.

    Returns:
        The first batch unchanged when ``num_samples == batch_size``.
        Otherwise a dict with the structure of a batch whose entries are the
        batches concatenated with ``np.concatenate`` and cut to
        ``num_samples``.
    """
    num_batches = math.ceil(num_samples / batch_size)
    # Keep one iterator for the whole call: calling iter() for every batch
    # would restart the loader each time and never reach later batches.
    data_iter = iter(dataloader)
    data_batches = []
    for _ in range(num_batches):
        try:
            batch = next(data_iter)
        except StopIteration:
            # loader exhausted: start over from the beginning
            data_iter = iter(dataloader)
            batch = next(data_iter)
        data_batches.append(batch)

    if num_samples == batch_size:
        return data_batches[0]

    # Join all batches and keep only the first num_samples entries
    expert_data_pos = {}
    for k, v in data_batches[0].items():
        if isinstance(v, dict):
            expert_data_pos[k] = {
                sub_key: np.concatenate(
                    [b[k][sub_key] for b in data_batches]
                )[:num_samples]
                for sub_key in v
            }
        else:
            expert_data_pos[k] = np.concatenate(
                [b[k] for b in data_batches]
            )[:num_samples]
    return expert_data_pos

import random
def apply_random_augmentations(images, device='cpu', p=0.5):
    """
    Randomly applies 1 to 3 augmentations on the images.

    A random non-empty subset of flip, rotation and cutout is chosen and
    applied in random order.

    Args:
        images (torch.Tensor): The input images with shape [B, C, H, W].
        device (torch.device): The device passed on to the augmentations.
        p (float): Probability passed to the flip and rotation augmentations
            (cutout does not take it).

    Returns:
        torch.Tensor: Augmented images.
    """

    def _flip(x):
        return random_flip(x, device, p)

    def _rotate(x):
        return random_rotation(x, device, p)

    def _cutout(x):
        return random_cutout(x, device)

    augmentations = [_flip, _rotate, _cutout]

    # Draw how many augmentations to apply (at least one), then which ones
    count = random.randint(1, len(augmentations))
    for augment in random.sample(augmentations, count):
        images = augment(images)

    return images

# original code from https://github.com/MishaLaskin/rad/blob/master/Data%20Aug%20Visualization.ipynb
def random_flip(images, device, p=.5):
    """Mirror each image of a batch along its last dimension with probability p.

    One decision is drawn per batch entry with np.random, so all channels of
    an entry are flipped together.

    Args:
        images: tensor of shape [B, C, H, W]
        device: device the flip mask is moved to
        p: probability of flipping an entry

    Returns:
        tensor of shape [B, C, H, W]
    """
    bs, channels, h, w = images.shape
    mirrored = images.flip([3])

    draws = np.random.uniform(0., 1., size=(bs,))
    flip_mask = torch.from_numpy(draws <= p)
    # one decision per entry, repeated over its channels: [bs] -> [bs, C]
    flip_mask = flip_mask[:, None].repeat(1, channels)
    flip_mask = flip_mask.type(images.dtype).to(device)
    flip_mask = flip_mask[:, :, None, None]

    blended = flip_mask * mirrored + (1 - flip_mask) * images
    return blended.view([bs, -1, h, w])

def random_rotation(images, device, p=.5):
    """Rotate each image of a batch by a random multiple of 90 degrees.

    With probability p an entry is rotated by 1, 2 or 3 quarter turns in the
    last two dimensions (chosen uniformly with np.random); otherwise it is
    left unchanged. The result is assembled by weighting the four rotated
    variants with 0/1 masks, so the variants must all have the input's shape.

    Args:
        images: tensor of shape [B, C, H, W]
        device: device the masks are moved to
        p: probability of rotating an entry

    Returns:
        tensor of shape [B, C, H, W]
    """
    bs, channels, h, w = images.shape
    # variants[k] is the batch rotated by k quarter turns
    variants = [images] + [images.rot90(k, [2, 3]) for k in (1, 2, 3)]

    draws = np.random.uniform(0., 1., size=(bs,))
    turns = np.random.randint(1, 4, size=(bs,))
    # 0 means "keep as is", 1..3 is the number of quarter turns
    choice = torch.from_numpy(turns * (draws <= p)).to(device)

    channel_ones = (
        torch.ones([1, channels]).type(choice.dtype).type(images.dtype).to(device)
    )
    weights = []
    for k in range(4):
        selected = (choice == k).to(choice.dtype)
        per_channel = selected[:, None] * channel_ones
        weights.append(per_channel[:, :, None, None])

    out = (
        weights[0] * variants[0]
        + weights[1] * variants[1]
        + weights[2] * variants[2]
        + weights[3] * variants[3]
    )
    return out.view([bs, -1, h, w])
    
def random_cutout(imgs, device, min_cut=10, max_cut=30):
    """
        Zero out one rectangular region in a copy of every image.

        For each image two values are drawn with np.random from
        [min_cut, max_cut); each is used both as the start offset and as the
        extent of the region along its axis. The input is not modified.

        args:
        imgs: shape (B,C,H,W)
        device: not used
        min_cut: lower bound (inclusive) of the drawn values
        max_cut: upper bound (exclusive) of the drawn values
    """
    n, _, _, _ = imgs.shape
    lefts = np.random.randint(min_cut, max_cut, n)
    tops = np.random.randint(min_cut, max_cut, n)

    cutouts = imgs.clone()
    for i, (left, top) in enumerate(zip(lefts, tops)):
        cutouts[i, :, top:top + top, left:left + left] = 0

    return cutouts