import numpy as np
import torch


def svd_flex(tensor, svd_string, max_D=None, cutoff=1e-10, sv_right=True,
             sv_vec=None):
    """Split a tensor into two factors joined by a new bond index, via SVD.

    ``svd_string`` has the einsum-like form ``"init->left,right"`` (spaces
    ignored), e.g. ``"abc->ai,ibc"``. All labels are lowercase letters;
    ``init`` labels the input's indices, while ``left`` and ``right`` split
    those indices between them and share exactly one new "bond" label. The
    input is matricized with the left indices as rows and the right indices
    as columns, decomposed with an SVD, and the factors are reshaped and
    permuted to the index orders given by ``left`` and ``right``.

    Args:
        tensor: Input tensor whose indices are labelled by ``init``.
        svd_string: Specification string as described above.
        max_D: Optional fixed bond dimension. Surplus singular values are
            dropped; if there are fewer than ``max_D``, the factors are
            zero-padded. With ``max_D`` set, singular values below ``cutoff``
            are zeroed in place so the bond dimension stays ``max_D``.
        cutoff: Singular values smaller than this are truncated.
        sv_right: If True, the singular values are multiplied into the right
            factor, otherwise into the left factor.
        sv_vec: Optional tensor filled in place with the singular values
            (after padding/truncation to ``max_D`` but before the ``cutoff``
            truncation). Its shape must match those singular values.

    Returns:
        ``(left_tensor, right_tensor, truncation)``, where ``truncation`` is
        the number of leading singular values that are not below ``cutoff``.
        The factors keep the dtype and device of ``tensor``.

    Raises:
        TypeError: If ``sv_vec`` has the wrong shape.
        RuntimeError: If every singular value is below ``cutoff``.
    """
    def prod(int_list):
        output = 1
        for num in int_list:
            output *= num
        return output

    with torch.no_grad():
        svd_string = svd_string.replace(' ', '')
        init_str, post_str = svd_string.split('->')
        left_str, right_str = post_str.split(',')

        assert all([c.islower() for c in init_str+left_str+right_str])
        assert len(set(init_str+left_str+right_str)) == len(init_str) + 1
        assert len(set(init_str))+len(set(left_str))+len(set(right_str)) == \
               len(init_str)+len(left_str)+len(right_str)

        # The bond label is the single character shared by both outputs
        bond_char = set(left_str).intersection(set(right_str)).pop()
        left_part = left_str.replace(bond_char, '')
        right_part = right_str.replace(bond_char, '')

        # Bring all left indices to the front, then matricize
        ein_str = f"{init_str}->{left_part+right_part}"
        tensor = torch.einsum(ein_str, [tensor]).contiguous()

        left_shape = list(tensor.shape[:len(left_part)])
        right_shape = list(tensor.shape[len(left_part):])
        left_dim, right_dim = prod(left_shape), prod(right_shape)

        tensor = tensor.view([left_dim, right_dim])

        # torch.svd returns the singular values already in descending order,
        # aligned with the columns of U and V, so no re-sorting is needed
        # (sorting svs alone would desynchronize it from U and V).
        left_mat, svs, right_mat = torch.svd(tensor)
        right_mat = torch.t(right_mat)

        if max_D and len(svs) > max_D:
            svs = svs[:max_D]
            left_mat = left_mat[:, :max_D]
            right_mat = right_mat[:max_D]
        elif max_D and len(svs) < max_D:
            # new_zeros inherits dtype and device, so float64 / CUDA inputs
            # are not silently converted to CPU float32 by the padding.
            copy_svs = svs.new_zeros([max_D])
            copy_svs[:len(svs)] = svs
            copy_left = left_mat.new_zeros([left_mat.size(0), max_D])
            copy_left[:, :left_mat.size(1)] = left_mat
            copy_right = right_mat.new_zeros([max_D, right_mat.size(1)])
            copy_right[:right_mat.size(0)] = right_mat
            svs, left_mat, right_mat = copy_svs, copy_left, copy_right

        if sv_vec is not None and svs.shape == sv_vec.shape:
            sv_vec[:] = svs
        elif sv_vec is not None and svs.shape != sv_vec.shape:
            raise TypeError(f"sv_vec.shape must be {list(svs.shape)}, but is "
                            f"currently {list(sv_vec.shape)}")

        # Count leading singular values that survive the cutoff
        truncation = 0
        for s in svs:
            if s < cutoff:
                break
            truncation += 1
        if truncation == 0:
            raise RuntimeError("SVD cutoff too large, attempted to truncate "
                               "tensor to bond dimension 0")

        if max_D:
            # Fixed bond dimension: zero out instead of shrinking
            svs[truncation:] = 0
            left_mat[:, truncation:] = 0
            right_mat[truncation:] = 0
        else:
            max_D = truncation
            svs = svs[:truncation]
            left_mat = left_mat[:, :truncation]
            right_mat = right_mat[:truncation]

        if sv_right:
            right_mat = torch.einsum('l,lr->lr', [svs, right_mat])
        else:
            left_mat = torch.einsum('lr,r->lr', [left_mat, svs])

        left_tensor = left_mat.view(left_shape+[max_D])
        right_tensor = right_mat.view([max_D]+right_shape)

        # Permute the bond index into the requested position if needed
        if left_str != left_part + bond_char:
            left_tensor = torch.einsum(f"{left_part+bond_char}->{left_str}",
                                    [left_tensor])
        if right_str != bond_char + right_part:
            right_tensor = torch.einsum(f"{bond_char+right_part}->{right_str}",
                                    [right_tensor])

        return left_tensor, right_tensor, truncation

def init_tensor(shape, bond_str, init_method):
    """Create a tensor of the given shape using a named initialization.

    Args:
        shape: Dimensions of the output tensor.
        bond_str: String of distinct characters, one label per index of
            ``shape``. The identity-based methods require the labels ``'l'``
            and ``'r'``, which mark the two bond indices.
        init_method: Either a method name (then ``std = 1e-9``), or a tuple
            ``(name, std)``; ``'min_random_eye'`` must be given as
            ``('min_random_eye', std, init_dim)``.

            * ``'random_eye'``: identity over the full ``l``/``r`` bond
              dimensions, broadcast over all other indices, plus noise.
            * ``'min_random_eye'``: identity only on the leading
              ``init_dim`` x ``init_dim`` block of the bond indices
              (``init_dim`` is clipped to the smaller bond dimension),
              broadcast over all other indices, plus noise.
            * ``'random_zero'``: noise only.

            Noise is ``std * torch.randn(shape)``.

    Returns:
        A new float tensor of shape ``shape``.

    Raises:
        ValueError: For an unknown method name, or ``'min_random_eye'``
            given without an ``init_dim``.
    """
    if isinstance(init_method, str):
        std = 1e-9
        init_dim = None
    else:
        init_str, std = init_method[0], init_method[1]
        init_dim = init_method[2] if init_str == 'min_random_eye' else None
        init_method = init_str

    assert len(shape) == len(bond_str)
    assert len(set(bond_str)) == len(bond_str)

    if init_method not in ['random_eye', 'min_random_eye', 'random_zero']:
        raise ValueError(f"Unknown initialization method: {init_method}")

    if init_method == 'random_zero':
        return std * torch.randn(shape)

    bond_chars = ['l', 'r']
    assert all([c in bond_str for c in bond_chars])
    l_idx, r_idx = bond_str.index('l'), bond_str.index('r')
    full_dims = [shape[l_idx], shape[r_idx]]

    if init_method == 'min_random_eye':
        if init_dim is None:
            raise ValueError("min_random_eye requires init_method to be "
                             "('min_random_eye', std, init_dim)")
        # Clip to the smallest bond dimension so the square identity block
        # always fits (both bond sizes become init_dim).
        init_dim = min([init_dim] + full_dims)
        eye_dims = [init_dim, init_dim]
        expand_shape = [init_dim if c in bond_chars else shape[i]
                        for i, c in enumerate(bond_str)]
    else:  # 'random_eye'
        eye_dims = full_dims
        expand_shape = list(shape)

    # eye_tensor is indexed (l, r); transpose it when 'r' precedes 'l' in
    # bond_str so the reshape below puts each axis at its own position.
    eye_tensor = torch.eye(eye_dims[0], eye_dims[1])
    if r_idx < l_idx:
        eye_tensor = eye_tensor.t()
    eye_shape = [expand_shape[i] if c in bond_chars else 1
                 for i, c in enumerate(bond_str)]
    eye_tensor = eye_tensor.reshape(eye_shape).expand(expand_shape)

    tensor = torch.zeros(shape)
    tensor[tuple(slice(dim) for dim in expand_shape)] = eye_tensor

    tensor += std * torch.randn(shape)

    return tensor



def onehot(labels, max_value):
    """Encode integer labels as one-hot rows.

    Args:
        labels: Sized iterable of integer class indices (list or tensor).
        max_value: Number of classes, i.e. the width of each row.

    Returns:
        A float tensor of shape ``[len(labels), max_value]`` whose row ``i``
        holds ``1.`` at column ``labels[i]`` and zeros elsewhere.
    """
    encoded = torch.zeros([len(labels), max_value])

    # Each iterated row is a view into `encoded`, so writing to it fills
    # the result in place.
    for row, label in zip(encoded, labels):
        row[label] = 1.

    return encoded

def joint_shuffle(input_data, input_labels, seed=0):
    """Shuffle data and labels along their first axis with one shared permutation.

    The permutation comes from a private ``np.random.RandomState(seed)``.
    With the default ``seed=0`` it is the same ordering that
    ``np.random.seed(0); np.random.shuffle(x)`` would produce, but the
    global NumPy RNG state is left untouched.

    Args:
        input_data: Tensor whose first axis indexes samples.
        input_labels: Tensor with one entry per sample, on CPU or CUDA in
            the same way as ``input_data``.
        seed: Seed for the permutation (default 0, deterministic).

    Returns:
        ``(data, labels)``: new tensors on the same devices as the inputs.
        The inputs themselves are not modified.

    Raises:
        ValueError: If the two inputs have different lengths.
    """
    assert input_data.is_cuda == input_labels.is_cuda
    if len(input_data) != len(input_labels):
        raise ValueError(f"input_data and input_labels must have the same "
                         f"length, got {len(input_data)} and "
                         f"{len(input_labels)}")

    perm = np.random.RandomState(seed).permutation(len(input_data))
    perm = torch.from_numpy(perm)

    # Indexing creates fresh tensors, so the inputs are never shuffled in
    # place and stay on their own devices.
    data = input_data[perm.to(input_data.device)]
    labels = input_labels[perm.to(input_labels.device)]

    return data, labels

def load_HV_data(length):
    """Generate the horizontal/vertical stripe dataset as torch tensors.

    Every nonzero, non-all-ones binary pattern of ``length`` bits yields two
    ``length`` x ``length`` images. In the label-0 image, row ``j`` is filled
    with bit ``j`` (horizontal stripes). In the label-1 image, column ``j``
    is filled with bit ``j`` (vertical stripes). The images are shuffled with
    a private ``np.random.RandomState(0)``, giving the same deterministic
    order as seeding the global RNG with 0, without touching global state.
    The first quarter of the shuffled images becomes the test set.

    Args:
        length: Side length of each square image.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Images
        are float32 tensors of shape ``[N, length, length]`` and labels are
        int64 tensors of shape ``[N]``.
    """
    num_images = 4 * (2**(length-1) - 1)
    num_patterns = num_images // 2
    split = num_images // 4

    if length > 14:
        print("load_HV_data will generate {} images, "
              "this could take a while...".format(num_images))

    # bits[p, j] is bit j (most significant first) of pattern p + 1, the
    # same digits as the zero-padded binary string of p + 1.
    codes = np.arange(1, num_patterns + 1, dtype=np.int64)
    shifts = np.arange(length - 1, -1, -1, dtype=np.int64)
    bits = ((codes[:, None] >> shifts) & 1).astype(np.float32)

    images = np.empty([num_images, length, length], dtype=np.float32)
    images[0::2] = bits[:, :, None]  # row j constant -> horizontal, label 0
    images[1::2] = bits[:, None, :]  # column j constant -> vertical, label 1

    # np.int was removed in NumPy 1.24; int64 maps to torch.long.
    labels = np.zeros(num_images, dtype=np.int64)
    labels[1::2] = 1

    perm = np.random.RandomState(0).permutation(num_images)
    images, labels = images[perm], labels[perm]

    train_images, train_labels = images[split:], labels[split:]
    test_images, test_labels = images[:split], labels[:split]

    return torch.from_numpy(train_images), \
           torch.from_numpy(train_labels), \
           torch.from_numpy(test_images), \
           torch.from_numpy(test_labels)
