
import math
import torch
import math
import torch.nn as nn
from utils import init_tensor, svd_flex
from contractables import SingleMat, MatRegion, OutputCore, ContractableList, \
                          EdgeVec, OutputMat

import torch.nn.functional as F 
import torch


class MPS(nn.Module):
    """Matrix product state model with a single output core.

    The chain is built from up to two ``InputRegion`` modules placed on either
    side of an ``OutputSite`` at position ``label_site``. The chain is held by
    a ``LinearRegion``. With ``adaptive_mode=True`` it is instead held by a
    ``MergedLinearRegion``, and the bias term is disabled.

    Args:
        input_dim: Number of input sites.
        output_dim: Size of the output index of the output core.
        bond_dim: Bond dimension used when initialising every core.
        feature_dim: Size of each site's feature vector.
        adaptive_mode: Use a ``MergedLinearRegion`` (forces ``use_bias=False``).
        periodic_bc: Periodic instead of open boundary conditions.
        parallel_eval: Forwarded to the linear region.
        label_site: Position of the output core; defaults to
            ``input_dim // 2``.
        path: Optional list/tensor of site indices (length ``input_dim``, or
            empty for the natural order) used to reorder inputs in
            ``forward``.
        init_std: Passed to ``init_tensor`` for every core.
        use_bias, fixed_bias: Forwarded to each ``InputRegion``.
        cutoff, merge_threshold: Forwarded to ``MergedLinearRegion``
            (adaptive mode only).
    """

    def __init__(self, input_dim, output_dim, bond_dim, feature_dim=2,
                 adaptive_mode=False, periodic_bc=False, parallel_eval=False,
                 label_site=None, path=None, init_std=1e-9, use_bias=True,
                 fixed_bias=True, cutoff=1e-10, merge_threshold=2000):
        super().__init__()
        if label_site is None:
            label_site = input_dim // 2
        assert label_site >= 0 and label_site <= input_dim

        # The merged cores used in adaptive mode have no bias term.
        if adaptive_mode:
            use_bias = False

        module_list = []
        init_args = {'bond_str': 'slri',
                     'shape': [label_site, bond_dim, bond_dim, feature_dim],
                     'init_method': ('min_random_eye' if adaptive_mode else
                     'random_zero', init_std, output_dim)}

        # Input cores to the left of the output core.
        if label_site > 0:
            tensor = init_tensor(**init_args)

            module_list.append(InputRegion(tensor, use_bias=use_bias,
                                           fixed_bias=fixed_bias))

        # The output core itself.
        tensor = init_tensor(shape=[output_dim, bond_dim, bond_dim],
            bond_str='olr', init_method=('min_random_eye' if adaptive_mode else
                                         'random_eye', init_std, output_dim))
        module_list.append(OutputSite(tensor))

        # Input cores to the right of the output core.
        if label_site < input_dim:
            init_args['shape'] = [input_dim-label_site, bond_dim, bond_dim,
                                  feature_dim]
            tensor = init_tensor(**init_args)
            module_list.append(InputRegion(tensor, use_bias=use_bias,
                                           fixed_bias=fixed_bias))

        if adaptive_mode:
            self.linear_region = MergedLinearRegion(module_list=module_list,
                                 periodic_bc=periodic_bc,
                                 parallel_eval=parallel_eval, cutoff=cutoff,
                                 merge_threshold=merge_threshold)

            # input_dim + 2 bond entries; with open boundaries the two
            # outermost entries are fixed to 1.
            self.bond_list = bond_dim * torch.ones(input_dim + 2,
                                                   dtype=torch.long)
            if not periodic_bc:
                self.bond_list[0], self.bond_list[-1] = 1, 1

            # Singular values per bond; -1 until an unmerge reports them.
            self.sv_list = -1. * torch.ones([input_dim + 2, bond_dim])

        else:
            self.linear_region = LinearRegion(module_list=module_list,
                                 periodic_bc=periodic_bc,
                                 parallel_eval=parallel_eval)
        assert len(self.linear_region) == input_dim

        # Compare against None: truth-testing a multi-element tensor path
        # would raise. An empty path keeps the natural site order.
        if path is not None:
            assert isinstance(path, (list, torch.Tensor))
            assert len(path) in (0, input_dim)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bond_dim = bond_dim
        self.feature_dim = feature_dim
        self.periodic_bc = periodic_bc
        self.adaptive_mode = adaptive_mode
        self.label_site = label_site
        self.path = path
        self.use_bias = use_bias
        self.fixed_bias = fixed_bias
        self.cutoff = cutoff
        self.merge_threshold = merge_threshold
        self.feature_map = None

    def forward(self, input_data):
        """Evaluate the MPS on ``input_data``.

        Returns:
            ``(output, qe_loss)`` as produced by the linear region. In
            adaptive mode, any bond dimensions and singular values reported
            by a re-merge are recorded in ``self.bond_list`` and
            ``self.sv_list``.
        """
        if self.path is not None and len(self.path) > 0:
            input_data = torch.stack([input_data[:, site_num]
                                      for site_num in self.path], dim=1)

        result = self.linear_region(input_data)

        # LinearRegion returns (output, qe_loss). MergedLinearRegion returns
        # ((output, qe_loss), bond_list, sv_list) on the calls where it
        # re-merges its cores, and (output, qe_loss) otherwise.
        if len(result) == 3:
            (output, qe_loss), new_bonds, new_svs = result

            assert len(new_bonds) == len(self.bond_list)
            assert len(new_bonds) == len(new_svs)
            for i, (bond_dim, svs) in enumerate(zip(new_bonds, new_svs)):
                # -1 is the placeholder for a bond that was not recomputed.
                if bond_dim != -1:
                    assert not (isinstance(svs, int) and svs == -1)
                    self.bond_list[i] = bond_dim
                    self.sv_list[i] = svs
        else:
            output, qe_loss = result
        return output, qe_loss

    def embed_input(self, input_data):
        """Map raw inputs to feature vectors of size ``feature_dim``.

        Rank-3 input is assumed to be pre-embedded and is returned unchanged
        if its last dimension equals ``feature_dim``. Rank-2 input is passed
        through ``self.feature_map`` if one is registered. Otherwise it uses
        the default map ``x -> [sin(pi*x/2), cos(pi*x/2)]``.

        Raises:
            ValueError: Rank-3 input with the wrong last dimension.
            RuntimeError: Default map requested with ``feature_dim != 2``.
        """

        assert len(input_data.shape) in [2, 3]
        assert input_data.size(1) == self.input_dim

        if len(input_data.shape) == 3:
            if input_data.size(2) != self.feature_dim:
                raise ValueError(f"input_data has wrong shape to be unembedded "
                "or pre-embedded data (input_data.shape = "
                f"{list(input_data.shape)}, feature_dim = {self.feature_dim})")
            return input_data

        embedded_shape = list(input_data.shape) + [self.feature_dim]

        if self.feature_map is not None:
            f_map = self.feature_map
            embedded_data = torch.stack([torch.stack([f_map(x) for x in batch])
                                                      for batch in input_data])

            assert embedded_data.shape == torch.Size(
                   [input_data.size(0), self.input_dim, self.feature_dim])

        else:
            if self.feature_dim != 2:
                raise RuntimeError(f"self.feature_dim = {self.feature_dim}, "
                      "but default feature_map requires self.feature_dim = 2")
            embedded_data = torch.empty(embedded_shape)
            half_pi = math.pi / 2
            embedded_data[:, :, 0] = torch.sin(half_pi * input_data)
            embedded_data[:, :, 1] = torch.cos(half_pi * input_data)

        return embedded_data

    def register_feature_map(self, feature_map):
        """Install a custom per-scalar feature map, or None for the default.

        Raises:
            ValueError: If ``feature_map(torch.tensor(0))`` does not have
                shape ``[feature_dim]``.
        """

        if feature_map is not None:
            out_shape = feature_map(torch.tensor(0)).shape
            needed_shape = torch.Size([self.feature_dim])
            if out_shape != needed_shape:
                raise ValueError("Given feature_map returns values of size "
                                f"{list(out_shape)}, but should return "
                                f"values of size {list(needed_shape)}")

        self.feature_map = feature_map

    def core_len(self):
        """Number of cores in the linear region."""
        return self.linear_region.core_len()

    def __len__(self):
        """Number of input sites."""
        return self.input_dim


class LinearRegion(nn.Module):
    """Ordered chain of MPS modules contracted into a single result.

    Args:
        module_list: Non-empty list of ``nn.Module`` instances. Each must
            support ``len()`` (number of input sites it consumes),
            ``core_len()`` and expose a ``tensor`` attribute.
        periodic_bc: If True, the outermost left/right bonds are traced
            together. Otherwise the chain is capped with edge vectors.
        parallel_eval: Forwarded to ``ContractableList.reduce`` for open
            boundaries.
        module_states: Unused.

    Raises:
        ValueError: If ``module_list`` is not a non-empty list of Modules.
    """
    def __init__(self, module_list, periodic_bc=False, parallel_eval=False,
                 module_states=None):
        # Use `not module_list`: an identity test against a fresh `[]` literal
        # is always False and would let empty lists through.
        if not isinstance(module_list, list) or not module_list:
            raise ValueError("Input to LinearRegion must be nonempty list")
        for i, item in enumerate(module_list):
            if not isinstance(item, nn.Module):
                raise ValueError("Input items to LinearRegion must be PyTorch "
                                f"Module instances, but item {i} is not")
        super().__init__()

        self.module_list = nn.ModuleList(module_list)
        self.periodic_bc = periodic_bc
        self.parallel_eval = parallel_eval

    def forward(self, input_data):
        """Feed each module its slice of ``input_data`` and contract the chain.

        Args:
            input_data: Rank-3 tensor whose dim 1 has ``len(self)`` entries,
                one per input site.

        Returns:
            ``(tensor, qe_loss)``. ``qe_loss`` is the sum over dims 0 and 1 of
            the tensor of the last module whose tensor is rank 3. It is None
            when no module has a rank-3 tensor, e.g. when the output core has
            been merged into a rank-4 core.
        """
        assert len(input_data.shape) == 3
        assert input_data.size(1) == len(self)
        lin_bonds = ['l', 'r']

        ind = 0
        qe_loss = None
        contractable_list = []
        for module in self.module_list:
            mod_len = len(module)
            # Single-site modules receive a rank-2 slice; others keep dim 1.
            if mod_len == 1:
                mod_input = input_data[:, ind]
            else:
                mod_input = input_data[:, ind:(ind+mod_len)]
            ind += mod_len

            contractable_list.append(module(mod_input))
            if module.tensor.dim() == 3:
                qe_loss = module.tensor.sum(dim=0).sum(dim=0)

        if self.periodic_bc:
            contractable = ContractableList(contractable_list).reduce(
                parallel_eval=True)
            tensor, bond_str = contractable.tensor, contractable.bond_str
            assert all(c in bond_str for c in lin_bonds)

            # Relabel both 'l' and 'r' as 'l' so einsum traces them together.
            in_str, out_str = "", ""
            for c in bond_str:
                if c in lin_bonds:
                    in_str += 'l'
                else:
                    in_str += c
                    out_str += c
            ein_str = in_str + "->" + out_str

            return torch.einsum(ein_str, [tensor]), qe_loss

        # Open boundaries: cap the chain with unit vectors e_0 sized to the
        # outermost left/right bonds, built on the same device/dtype as the
        # cores they attach to.
        end_items = [contractable_list[0], contractable_list[-1]]
        bond_inds = [item.bond_str.index(c)
                     for item, c in zip(end_items, lin_bonds)]
        end_vecs = []
        for item, bond_ind in zip(end_items, bond_inds):
            vec = item.tensor.new_zeros(item.tensor.size(bond_ind))
            vec[0] = 1
            end_vecs.append(vec)
        contractable_list.insert(0, EdgeVec(end_vecs[0], is_left_vec=True))
        contractable_list.append(EdgeVec(end_vecs[1], is_left_vec=False))

        output = ContractableList(contractable_list).reduce(
            parallel_eval=self.parallel_eval)

        return output.tensor, qe_loss

    def core_len(self):
        """Total number of cores across all modules."""
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """Total number of input sites consumed by all modules."""
        return sum([len(module) for module in self.module_list])

class MergedLinearRegion(LinearRegion):
    """LinearRegion that contracts merged (two-site) cores.

    Two merged versions of the initial module list are stored as
    ``module_list_0`` and ``module_list_1``. They differ in the site parity
    (``offset``) at which neighbouring cores are paired. ``forward`` contracts
    the version selected by ``self.offset``. Once ``input_counter`` reaches
    ``merge_threshold``, the active version is split back into single-site
    cores by SVD (``unmerge``) and re-merged with the other offset.
    """
    def __init__(self, module_list, periodic_bc=False, parallel_eval=False,
                 cutoff=1e-10, merge_threshold=2000):
        super().__init__(module_list, periodic_bc, parallel_eval)

        # Build both merged versions from the unmerged self.module_list set
        # by LinearRegion.__init__, then make the offset-0 version active.
        self.offset = 0
        self.merge(offset=self.offset)
        self.merge(offset=(self.offset+1)%2)
        self.module_list = getattr(self, f"module_list_{self.offset}")

        # Running total of input_data.size(0) over forward calls since the
        # last re-merge (presumably the number of samples seen).
        self.input_counter = 0
        self.merge_threshold = merge_threshold
        self.cutoff = cutoff

    def forward(self, input_data):
        """Contract the active merged list, re-merging first if due.

        Returns:
            On ordinary calls, ``LinearRegion.forward``'s result, i.e. a
            ``(tensor, QELoss)`` tuple. On a call where the threshold was
            reached, returns ``(that_result, bond_list, sv_list)`` with the
            lists produced by ``unmerge``.
        """

        if self.input_counter >= self.merge_threshold:
            # Split the active merged cores, then re-merge with the opposite
            # pairing offset. The counter keeps any overshoot.
            bond_list, sv_list = self.unmerge(cutoff=self.cutoff)
            self.offset = (self.offset + 1) % 2
            self.merge(offset=self.offset)
            self.input_counter -= self.merge_threshold

            self.module_list = getattr(self, f"module_list_{self.offset}")
        else:
            bond_list, sv_list = None, None
        self.input_counter += input_data.size(0)
        output = super().forward(input_data)

        if bond_list:
            return output, bond_list, sv_list
        else:
            return output

    @torch.no_grad()
    def merge(self, offset):
        """Build or refresh ``module_list_{offset}`` from ``self.module_list``.

        Args:
            offset: 0 or 1, the site parity at which cores are paired.

        Cores exposing ``merge`` (InputRegion) first merge their own interior
        pairs. Then adjacent OutputSite/InputSite pairs starting at a site
        whose parity equals ``offset`` are fused into MergedOutput cores. If
        ``module_list_{offset}`` already exists, its tensors are overwritten
        in place and the shapes must match. Otherwise a new ModuleList is
        stored under that name.
        """

        assert offset in [0, 1]

        unmerged_list = self.module_list

        # site_num starts at `offset` and advances by each core's core_len(),
        # so every region is told the pairing parity for its position.
        site_num = offset
        merged_list = []
        for core in unmerged_list:
            assert not isinstance(core, MergedInput)
            assert not isinstance(core, MergedOutput)

            if hasattr(core, 'merge'):
                merged_list.extend(core.merge(offset=site_num%2))
            else:
                merged_list.append(core)

            site_num += core.core_len()

        # Repeatedly fuse adjacent OutputSite/InputSite pairs whose running
        # site count has parity `offset`, until a pass changes nothing.
        while True:
            mod_num, site_num = 0, 0
            combined_list = []
            while mod_num < len(merged_list) - 1:
                left_core, right_core = merged_list[mod_num: mod_num+2]
                new_core = self.combine(left_core, right_core,
                                                   merging=True)

                if new_core is None or offset != site_num % 2:
                    combined_list.append(left_core)
                    mod_num += 1
                    site_num += left_core.core_len()

                else:
                    assert new_core.core_len() == left_core.core_len() + \
                                                  right_core.core_len()
                    combined_list.append(new_core)
                    mod_num += 2
                    site_num += new_core.core_len()

                # Carry over a trailing core that has no right neighbour.
                if mod_num == len(merged_list)-1:
                    combined_list.append(merged_list[mod_num])
                    mod_num += 1

            if len(combined_list) == len(merged_list):
                break
            else:
                merged_list = combined_list

        list_name = f"module_list_{offset}"
        if not hasattr(self, list_name):
            setattr(self, list_name, nn.ModuleList(merged_list))

        else:
            # Copy values into the existing modules rather than replacing
            # them (presumably so existing Parameter objects stay in use).
            module_list = getattr(self, list_name)
            assert len(module_list) == len(merged_list)
            for i in range(len(module_list)):
                assert module_list[i].tensor.shape == \
                       merged_list[i].tensor.shape
                module_list[i].tensor[:] = merged_list[i].tensor

    @torch.no_grad()
    def unmerge(self, cutoff=1e-10):
        """Split the active merged cores back into single-site cores.

        Args:
            cutoff: Forwarded to each merged core's ``unmerge``.

        Returns:
            ``(bond_list, sv_list)``: the bond dimensions and singular-value
            tensors reported by the cores' ``unmerge``. ``-1`` marks bonds
            that no SVD produced. As a side effect, ``self.module_list`` is
            set to the unmerged cores, rescaled so that every core has the
            geometric-mean norm. The product of the scale factors is 1.
        """
        list_name = f"module_list_{self.offset}"
        merged_list = getattr(self, list_name)

        unmerged_list, bond_list, sv_list = [], [-1], [-1]
        for core in merged_list:

            if hasattr(core, 'unmerge'):
                new_cores, new_bonds, new_svs = core.unmerge(cutoff)
                unmerged_list.extend(new_cores)
                # Drop the leading entry: it duplicates the previous core's
                # trailing bond.
                bond_list.extend(new_bonds[1:])
                sv_list.extend(new_svs[1:])
            else:
                assert not isinstance(core, InputRegion)
                unmerged_list.append(core)
                bond_list.append(-1)
                sv_list.append(-1)

        # Repeatedly concatenate adjacent InputSite/InputRegion cores into
        # InputRegions until a pass changes nothing.
        while True:
            mod_num = 0
            combined_list = []

            while mod_num < len(unmerged_list) - 1:
                left_core, right_core = unmerged_list[mod_num: mod_num+2]
                new_core = self.combine(left_core, right_core,
                                                   merging=False)

                if new_core is None:
                    combined_list.append(left_core)
                    mod_num += 1

                else:
                    combined_list.append(new_core)
                    mod_num += 2

                if mod_num == len(unmerged_list)-1:
                    combined_list.append(unmerged_list[mod_num])
                    mod_num += 1

            if len(combined_list) == len(unmerged_list):
                break
            else:
                unmerged_list = combined_list

        # Rescale each core to norm exp(mean log-norm). The product of all
        # scale factors is 1.
        log_norms = []
        for core in unmerged_list:
            log_norms.append([torch.log(norm) for norm in core.get_norm()])
        log_scale = sum([sum(ns) for ns in log_norms])
        log_scale /= sum([len(ns) for ns in log_norms])

        scales = [[torch.exp(log_scale-n) for n in ns] for ns in log_norms]
        for core, these_scales in zip(unmerged_list, scales):
            core.rescale_norm(these_scales)

        self.module_list = nn.ModuleList(unmerged_list)
        return bond_list, sv_list

    def combine(self, left_core, right_core, merging):
        """Try to fuse two adjacent cores; return the new core or None.

        With ``merging=True``, an InputSite/OutputSite pair (either order)
        becomes a MergedOutput. With ``merging=False``, an
        InputRegion/InputSite pair (either order) becomes one InputRegion by
        concatenating along dim 0.

        NOTE(review): the concatenated InputRegion uses InputRegion's default
        bias settings, not those of the original regions. Confirm intended.
        """

        if merging and ((isinstance(left_core, OutputSite) and
                         isinstance(right_core, InputSite)) or
                            (isinstance(left_core, InputSite) and
                            isinstance(right_core, OutputSite))):

            left_site = isinstance(left_core, InputSite)
            if left_site:
                new_tensor = torch.einsum('lui,our->olri', [left_core.tensor,
                                                            right_core.tensor])
            else:
                new_tensor = torch.einsum('olu,uri->olri', [left_core.tensor,
                                                            right_core.tensor])
            return MergedOutput(new_tensor, left_output=(not left_site))

        elif not merging and ((isinstance(left_core, InputRegion) and
                               isinstance(right_core, InputSite)) or
                                    (isinstance(left_core, InputSite) and
                                    isinstance(right_core, InputRegion))):

            # Give the single site a leading length-1 axis so it can be
            # concatenated with the region's stacked cores.
            left_site = isinstance(left_core, InputSite)
            if left_site:
                left_tensor = left_core.tensor.unsqueeze(0)
                right_tensor = right_core.tensor
            else:
                left_tensor = left_core.tensor
                right_tensor = right_core.tensor.unsqueeze(0)

            assert left_tensor.shape[1:] == right_tensor.shape[1:]
            new_tensor = torch.cat([left_tensor, right_tensor])

            return InputRegion(new_tensor)

        else:
            return None

    def core_len(self):
        """Total number of cores across the current module list."""
        return sum([module.core_len() for module in self.module_list])

    def __len__(self):
        """Total number of input sites consumed by the current module list."""
        return sum([len(module) for module in self.module_list])

class InputRegion(nn.Module):
    """A run of consecutive input cores stored as one 'slri' tensor.

    Args:
        tensor: Rank-4 core tensor with equal sizes on dims 1 and 2.
        use_bias: If True, add ``bias_mat`` to every per-site matrix.
        fixed_bias: Store the bias as a buffer (True) or a Parameter (False).
        bias_mat: Rank-2 or rank-3 bias; defaults to the identity.
        ephemeral: Store both tensor and bias as buffers.

    NOTE(review): with ``use_bias=False`` and ``fixed_bias=False`` the None
    bias is wrapped in ``nn.Parameter``, which yields an empty parameter.
    """
    def __init__(self, tensor, use_bias=True, fixed_bias=True, bias_mat=None,
                 ephemeral=False):
        super().__init__()

        assert tensor.dim() == 4
        assert tensor.size(1) == tensor.size(2)

        if use_bias:
            assert bias_mat is None or isinstance(bias_mat, torch.Tensor)
            if bias_mat is None:
                bias_mat = torch.eye(tensor.size(1)).unsqueeze(0)
            assert bias_mat.dim() in [2, 3]
            if bias_mat.dim() == 2:
                bias_mat = bias_mat.unsqueeze(0)

        core_tensor = tensor.contiguous()
        if ephemeral:
            self.register_buffer(name='tensor', tensor=core_tensor)
            self.register_buffer(name='bias_mat', tensor=bias_mat)
        else:
            self.register_parameter(name='tensor',
                                    param=nn.Parameter(core_tensor))
            if fixed_bias:
                self.register_buffer(name='bias_mat', tensor=bias_mat)
            else:
                self.register_parameter(name='bias_mat',
                                        param=nn.Parameter(bias_mat))

        self.use_bias = use_bias
        self.fixed_bias = fixed_bias

    def forward(self, input_data):
        """Contract each core with its site's features; return a MatRegion."""
        core_tensor = self.tensor.cuda()
        input_data = input_data.cuda()
        assert input_data.dim() == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == core_tensor.size(3)

        mats = torch.einsum('slri,bsi->bslr', [core_tensor, input_data])
        if self.use_bias:
            mats = mats + self.bias_mat.unsqueeze(0).expand_as(mats)
        return MatRegion(mats)

    def merge(self, offset):
        """Pair neighbouring cores into MergedInput cores.

        ``offset`` (0 or 1) chooses whether pairing starts at the first or
        the second core. An unpaired core at either end is returned as an
        InputSite. An empty region yields ``[None]``.
        """
        assert offset in [0, 1]
        n_sites = self.core_len()
        odd_length = n_sites % 2

        if n_sites == 0:
            return [None]

        if offset == 1 and odd_length == 1:
            pieces = [self[0], self[1:].merge(offset=0)[0]]
        elif offset == 1:
            pieces = [self[0], self[1:-1].merge(offset=0)[0], self[-1]]
        elif odd_length == 1:
            pieces = [self[:-1].merge(offset=0)[0], self[-1]]
        else:
            evens, odds = self.tensor[0::2], self.tensor[1::2]
            assert len(evens) == len(odds)
            pair_cores = torch.einsum('slui,surj->slrij', [evens, odds])
            pieces = [MergedInput(pair_cores)]

        return [piece for piece in pieces if piece is not None]

    def __getitem__(self, key):
        """Slice -> InputRegion of those cores; int -> single InputSite."""
        assert isinstance(key, (int, slice))

        sub_tensor = self.tensor[key]
        if isinstance(key, slice):
            return InputRegion(sub_tensor)
        return InputSite(sub_tensor)

    def get_norm(self):
        """Frobenius norm of each core, as a list."""
        return [torch.norm(site_core) for site_core in self.tensor]

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """Multiply core k in place by ``scale_list[k]``."""
        assert len(scale_list) == len(self.tensor)

        for site_core, factor in zip(self.tensor, scale_list):
            site_core *= factor

    def core_len(self):
        """Number of cores (equal to ``len(self)``)."""
        return len(self)

    def __len__(self):
        """Number of input sites, i.e. the size of dim 0 of the tensor."""
        return self.tensor.size(0)

class MergedInput(nn.Module):
    """Pairs of neighbouring input cores stored as one 'slrij' tensor.

    Each entry along dim 0 is one pair. In ``forward``, index ``i``
    contracts with the even-position site input and ``j`` with the
    odd-position one.
    """
    def __init__(self, tensor):
        dims = tensor.shape

        assert len(dims) == 5
        assert dims[1] == dims[2]
        assert dims[3] == dims[4]

        super().__init__()

        self.register_parameter(name='tensor', param=nn.Parameter(tensor.contiguous()))

    def forward(self, input_data):
        """Contract every pair core with its two sites' features."""
        core_tensor = self.tensor.cuda()
        input_data = input_data.cuda()
        assert input_data.dim() == 3
        assert input_data.size(1) == len(self)
        assert input_data.size(2) == core_tensor.size(3)
        assert input_data.size(1) % 2 == 0

        even_inputs = input_data[:, 0::2]
        odd_inputs = input_data[:, 1::2]
        half_contracted = torch.einsum('slrij,bsj->bslri',
                                       [core_tensor, odd_inputs])
        mats = torch.einsum('bslri,bsi->bslr', [half_contracted, even_inputs])

        return MatRegion(mats)

    def unmerge(self, cutoff=1e-10):
        """Split each pair core by ``svd_flex`` into two single-site cores.

        Returns:
            ``([InputRegion], bond_list, sv_list)``. The lists start with a
            ``-1`` placeholder and hold one (bond, -1) / (sv_vec, -1) entry
            pair per merged core.
        """
        max_D = self.tensor.size(1)

        cores, bonds, svs = [], [-1], [-1]
        for pair_core in self.tensor:
            sv_vec = torch.empty(max_D)
            left, right, bond_dim = svd_flex(pair_core, 'lrij->lui,urj',
                                             max_D, cutoff, sv_vec=sv_vec)
            cores.extend((left, right))
            bonds.extend((bond_dim, -1))
            svs.extend((sv_vec, -1))

        return [InputRegion(torch.stack(cores))], bonds, svs

    def get_norm(self):
        """Frobenius norm of each pair core, as a list."""
        return [torch.norm(pair_core) for pair_core in self.tensor]

    @torch.no_grad()
    def rescale_norm(self, scale_list):
        """Multiply pair core k in place by ``scale_list[k]``."""
        assert len(scale_list) == len(self.tensor)

        for pair_core, factor in zip(self.tensor, scale_list):
            pair_core *= factor

    def core_len(self):
        """Number of single-site cores represented (two per pair)."""
        return len(self)

    def __len__(self):
        """Number of input sites consumed: two per pair core."""
        return 2 * self.tensor.size(0)

class InputSite(nn.Module):
    """A single input core stored as an 'lri' tensor."""
    def __init__(self, tensor):
        super().__init__()
        core = nn.Parameter(tensor.contiguous())
        self.register_parameter(name='tensor', param=core)

    def forward(self, input_data):
        """Contract the core with one site's features; return a SingleMat."""
        core_tensor = self.tensor.cuda()
        input_data = input_data.cuda()
        assert input_data.dim() == 2
        assert input_data.size(1) == core_tensor.size(2)
        return SingleMat(torch.einsum('lri,bi->blr',
                                      [core_tensor, input_data]))

    def get_norm(self):
        """One-element list with the core's Frobenius norm."""
        return [torch.norm(self.tensor)]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """Scale the core in place by a number or a one-element list."""
        if isinstance(scale, list):
            assert len(scale) == 1
            (scale,) = scale
        self.tensor *= scale

    def core_len(self):
        """Always one core."""
        return 1

    def __len__(self):
        """Consumes exactly one input site."""
        return 1

class OutputSite(nn.Module):
    """The output core; consumes no input sites.

    Presumably indexed (output, left, right), as MPS builds it with
    bond_str 'olr'. Confirm for other callers.
    """

    def __init__(self, tensor):
        super().__init__()
        core = nn.Parameter(tensor.contiguous())
        self.register_parameter(name='tensor', param=core)

    def forward(self, input_data):
        """Wrap the core in an OutputCore; ``input_data`` is ignored."""
        return OutputCore(self.tensor)

    def get_norm(self):
        """One-element list with the core's Frobenius norm."""
        return [torch.norm(self.tensor)]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """Scale the core in place by a number or a one-element list."""
        if isinstance(scale, list):
            assert len(scale) == 1
            (scale,) = scale
        self.tensor *= scale

    def core_len(self):
        """Always one core."""
        return 1

    def __len__(self):
        """Consumes no input sites."""
        return 0

class MergedOutput(nn.Module):
    """Output core fused with one neighbouring input core ('olri' tensor).

    ``left_output`` records whether the output core was on the left of the
    input core; ``unmerge`` uses it to split the tensor back apart.
    """

    def __init__(self, tensor, left_output):
        assert tensor.dim() == 4
        super().__init__()

        core = nn.Parameter(tensor.contiguous())
        self.register_parameter(name='tensor', param=core)
        self.left_output = left_output

    def forward(self, input_data):
        """Contract the input index with one site's features."""
        core_tensor = self.tensor.cuda()
        input_data = input_data.cuda()
        assert input_data.dim() == 2
        assert input_data.size(1) == core_tensor.size(3)

        return OutputCore(torch.einsum('olri,bi->bolr',
                                       [core_tensor, input_data]))

    def unmerge(self, cutoff=1e-10):
        """Split into an OutputSite and an InputSite via ``svd_flex``.

        Returns:
            ``(cores, [-1, bond_dim, -1], [-1, sv_vec, -1])`` with ``cores``
            in left-to-right order.
        """
        core_tensor = self.tensor
        if self.left_output:
            max_D = core_tensor.size(2)
            sv_vec = torch.empty(max_D)
            output_core, input_core, bond_dim = svd_flex(
                core_tensor, 'olri->olu,uri', max_D, cutoff, sv_vec=sv_vec)
            new_cores = [OutputSite(output_core), InputSite(input_core)]
        else:
            max_D = core_tensor.size(1)
            sv_vec = torch.empty(max_D)
            output_core, input_core, bond_dim = svd_flex(
                core_tensor, 'olri->our,lui', max_D, cutoff, sv_vec=sv_vec)
            new_cores = [InputSite(input_core), OutputSite(output_core)]

        return new_cores, [-1, bond_dim, -1], [-1, sv_vec, -1]

    def get_norm(self):
        """One-element list with the core's Frobenius norm."""
        return [torch.norm(self.tensor)]

    @torch.no_grad()
    def rescale_norm(self, scale):
        """Scale the core in place by a number or a one-element list."""
        if isinstance(scale, list):
            assert len(scale) == 1
            (scale,) = scale
        self.tensor *= scale

    def core_len(self):
        """Represents two cores (output + input)."""
        return 2

    def __len__(self):
        """Consumes one input site."""
        return 1

class InitialVector(nn.Module):
    """Boundary vector: ones, optionally zeroed from index ``fill_dim`` on.

    Args:
        bond_dim: Length of the vector.
        fill_dim: If given, entries ``fill_dim:`` are set to 0.
        fixed_vec: Store as a buffer (True) or trainable Parameter (False).
        is_left_vec: Must be a bool; passed through to ``EdgeVec``.
    """
    def __init__(self, bond_dim, fill_dim=None, fixed_vec=True,
                 is_left_vec=True):
        super().__init__()

        edge = torch.ones(bond_dim)
        if fill_dim is not None:
            assert 0 <= fill_dim <= bond_dim
            edge[fill_dim:] = 0

        edge.requires_grad = not fixed_vec
        if fixed_vec:
            self.register_buffer(name='vec', tensor=edge)
        else:
            self.register_parameter(name='vec', param=nn.Parameter(edge))

        assert isinstance(is_left_vec, bool)
        self.is_left_vec = is_left_vec

    def forward(self):
        """Wrap the vector in an EdgeVec."""
        return EdgeVec(self.vec, self.is_left_vec)

    def core_len(self):
        """Counts as one core."""
        return 1

    def __len__(self):
        """Consumes no input sites."""
        return 0

class TerminalOutput(nn.Module):
    """Boundary matrix of shape (bond_dim, output_dim) feeding the output.

    With ``fixed_mat=True`` the matrix is ``torch.eye(bond_dim, output_dim)``
    stored as a buffer. Otherwise that identity plus Gaussian noise scaled by
    ``1 / bond_dim`` is stored as a trainable Parameter.

    Raises:
        ValueError: ``fixed_mat=True`` with ``output_dim > bond_dim``.
    """
    def __init__(self, bond_dim, output_dim, fixed_mat=False,
                 is_left_mat=False):
        super().__init__()

        if fixed_mat and output_dim > bond_dim:
            raise ValueError("With fixed_mat=True, TerminalOutput currently "
                             "only supports initialization for bond_dim >= "
                             "output_dim, but here bond_dim="
                            f"{bond_dim} and output_dim={output_dim}")

        base = torch.eye(bond_dim, output_dim)
        if fixed_mat:
            base.requires_grad = False
            self.register_buffer(name='mat', tensor=base)
        else:
            noisy = base + torch.randn_like(base) / bond_dim
            noisy.requires_grad = True
            self.register_parameter(name='mat', param=nn.Parameter(noisy))

        assert isinstance(is_left_mat, bool)
        self.is_left_mat = is_left_mat

    def forward(self):
        """Wrap the matrix in an OutputMat."""
        return OutputMat(self.mat, self.is_left_mat)

    def core_len(self):
        """Counts as one core."""
        return 1

    def __len__(self):
        """Consumes no input sites."""
        return 0
