import torch
import numpy as np
from .comlib import *
from . import compute_grad, convert_vec_tuple

def tensorTupleToDevice(tup,device,trans_cpu=True):
    """
    Move every tensor of an iterable to ``device`` and return them as a tuple.

    Args:
        tup: iterable of tensors.
        device: target passed unchanged to ``Tensor.to``.
        trans_cpu: when true, each tensor goes through ``.cpu()`` before
            ``.to(device)``; when false, ``.to(device)`` is called directly.

    Returns:
        A new tuple holding the results of the ``.to(device)`` calls, in the
        same order as ``tup``.
    """
    moved = []
    for tensor in tup:
        # Optionally stage the tensor on the host before the final transfer.
        staged = tensor.cpu() if trans_cpu else tensor
        moved.append(staged.to(device))
    return tuple(moved)

class TensorTuple:
    """
    Wrapper around a tuple of tensors that supports element-wise arithmetic.

    ``+`` and ``-`` combine two instances tensor by tensor; ``*`` and ``/``
    apply one operand to every tensor. Arithmetic results are built with
    ``self.__class__`` so subclasses keep their type.
    """

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def __init__(self, tensors):
        """
        Initialize a TensorTuple.

        Args:
            tensors: a tuple of tensors. It is stored as passed in (no copy).
        """
        self.tensors = tensors
        # Only the first tensor is inspected for the device. An empty tuple
        # has no device, so record None instead of raising IndexError.
        self.device = tensors[0].device if len(tensors) > 0 else None

    def _check_operand(self, other):
        """Raise unless ``other`` is a TensorTuple with as many tensors as ``self``."""
        if not isinstance(other, TensorTuple):
            raise TypeError("Both operands must be TensorTuple instances.")
        if len(self.tensors) != len(other.tensors):
            raise ValueError("Both TensorTuple instances must have the same number of tensors.")

    def __add__(self, other):
        """
        Overload ``+``: add two TensorTuple objects tensor by tensor.

        Args:
            other: another TensorTuple.

        Returns:
            A new object of ``self``'s class holding the pairwise sums.

        Raises:
            TypeError: if ``other`` is not a TensorTuple.
            ValueError: if the two objects hold a different number of tensors.
        """
        self._check_operand(other)
        result_tensors = tuple(a + b for a, b in zip(self.tensors, other.tensors))
        return self.__class__(result_tensors)

    def __sub__(self, other):
        """
        Overload ``-``: subtract two TensorTuple objects tensor by tensor.

        Args:
            other: another TensorTuple.

        Returns:
            A new object of ``self``'s class holding the pairwise differences.

        Raises:
            TypeError: if ``other`` is not a TensorTuple.
            ValueError: if the two objects hold a different number of tensors.
        """
        self._check_operand(other)
        result_tensors = tuple(a - b for a, b in zip(self.tensors, other.tensors))
        return self.__class__(result_tensors)

    def __repr__(self):
        """Return the string representation of the TensorTuple."""
        return f"TensorTuple({self.tensors})"

    def __mul__(self, scalar):
        """
        Overload ``*``: multiply every tensor by ``scalar``.

        Args:
            scalar: the right-hand operand (documented as an int or float).

        Returns:
            A new object of ``self``'s class holding each tensor times ``scalar``.
        """
        result_tensors = tuple(t * scalar for t in self.tensors)
        return self.__class__(result_tensors)

    def __rmul__(self, scalar):
        """
        Overload ``*`` with the scalar on the left; delegates to ``__mul__``.

        Args:
            scalar: the left-hand operand (documented as an int or float).

        Returns:
            A new object of ``self``'s class holding each tensor times ``scalar``.
        """
        return self.__mul__(scalar)

    def __truediv__(self, scalar):
        """
        Overload ``/``: divide every tensor by ``scalar``.

        Args:
            scalar: the divisor applied to each tensor.

        Returns:
            A new object of ``self``'s class holding each tensor divided by ``scalar``.
        """
        result_tensors = tuple(t / scalar for t in self.tensors)
        return self.__class__(result_tensors)

    def __eq__(self, other):
        """
        Compare two TensorTuple objects with ``torch.equal`` on each pair.

        Returns:
            True when both hold the same number of tensors and every pair
            satisfies ``torch.equal``; False otherwise. ``NotImplemented`` is
            returned for operands that are not TensorTuple instances, so that
            ``tt == 5`` evaluates to False instead of raising AttributeError.
        """
        if not isinstance(other, TensorTuple):
            return NotImplemented
        if len(self.tensors) != len(other.tensors):
            return False
        return all(
            torch.equal(tensor1, tensor2)
            for tensor1, tensor2 in zip(self.tensors, other.tensors)
        )

    # Commented-out in-place variant (not active):
    # def to_(self,device,trans_cpu=True):
    #     for tensor1 in self.tensors:
    #         if trans_cpu:
    #             tensor1=tensor1.cpu()
    #         tensor1=tensor1.to(device)
    def to(self,device,trans_cpu=True):
        """
        Return a new object whose tensors were moved with ``tensorTupleToDevice``.

        Args:
            device: target forwarded to ``tensorTupleToDevice``.
            trans_cpu: forwarded to ``tensorTupleToDevice``.

        Returns:
            A new object of ``self``'s class (matching the arithmetic
            operators, so subclasses are not downgraded to TensorTuple).
        """
        new_tup=tensorTupleToDevice(self.tensors,device,trans_cpu)
        return self.__class__(new_tup)

    # def sameSize(self,other):
    #     if len(self.tensors) != len(other.tensors):
    #         return False
    #     for tensor1, tensor2 in zip(self.tensors, other.tensors):


class ModelPara(TensorTuple):
    """TensorTuple with reductions that treat all wrapped tensors as one vector."""

    def __init__(self, tensors):
        """Store ``tensors`` through the TensorTuple initializer."""
        super().__init__(tensors)

    def flatten(self):
        """Return ``convert_vec_tuple.flatten_to_vec`` applied to the wrapped tensors."""
        flat = convert_vec_tuple.flatten_to_vec(self.tensors)
        return flat

    def norm(self):
        """Return the square root of the sum of squared elements over all tensors."""
        # Builtin sum starts from the int 0, like the explicit accumulator would.
        squared_total = sum(torch.sum(t * t) for t in self.tensors)
        return torch.sqrt(squared_total)

    def dot(self,other):
        """
        Return the sum of element-wise products between ``self`` and ``other``.

        Args:
            other: object exposing ``tensors`` with matching count and shapes.

        Raises:
            ValueError: if the tensor counts differ, or if a pair of tensors
                at the same position differs in shape.
        """
        # The two tuples must pair up one to one.
        if len(self.tensors) != len(other.tensors):
            raise ValueError("The tuples must have the same length.")

        total = 0
        for lhs, rhs in zip(self.tensors, other.tensors):
            # Shapes are checked pair by pair, right before each product.
            if lhs.shape != rhs.shape:
                raise ValueError("Tensors at each position must have the same shape.")
            total = total + (lhs * rhs).sum()

        return total
    
# class ModelNamedPara(ModelPara):
#     def __init__(self, tensors, model):
#         super().__init__(tensors)
#         self.form_dict=...

class ModelParaS(TensorTuple):
    """
    TensorTuple whose tensors are indexed along dim 0.

    ``num()`` reads the length of the first tensor only; presumably every
    tensor shares that leading size (not checked here -- TODO confirm).
    """

    def __init__(self, tensors):
        """Store ``tensors`` through the TensorTuple initializer."""
        super().__init__(tensors)

    def num(self):
        """Return the dim-0 length of the first tensor."""
        return len(self.tensors[0])

    def index_dim0(self,i):
        """Return a TensorTuple holding ``t[i]`` for every tensor ``t``."""
        return TensorTuple(tuple(tens[i] for tens in self.tensors))

    def index_subset_dim0(self,i):
        """Return a ModelParaS holding ``t[i]`` for every tensor ``t``."""
        return ModelParaS(tuple(tens[i] for tens in self.tensors))

    def norms(self):
        """
        Return one L2 norm per dim-0 index, accumulated over all tensors.

        Returns:
            A tensor of length ``num()``: the square root of the summed
            squared per-index norms of every tensor.
        """
        result = torch.zeros(self.num(),device=self.device)
        for tensor in self.tensors:
            if tensor.dim() > 1:
                squared = torch.norm(tensor,dim=tuple(range(1, tensor.dim())))**2
            else:
                # A 1-D tensor would give dim=(), and torch reductions have
                # historically treated an empty dim tuple as "reduce over all
                # dims". That collapses the tensor to one scalar which then
                # broadcasts into every entry. With no trailing dims, each
                # per-index norm is just the absolute value.
                squared = tensor.abs()**2
            result = result+squared
        return torch.sqrt(result)

    def mean(self):
        """Return a ModelPara holding the mean over dim 0 of every tensor."""
        return ModelPara(tuple(torch.mean(t,dim=0) for t in self.tensors))

    def _expanded_coeffs(self, coeffs, t):
        """Validate ``coeffs`` against ``num()`` and shape it to broadcast over ``t``."""
        if coeffs.size(0) != self.num():
            raise ValueError("coeffs的长度必须与num相同")
        # Shape (-1, 1, ..., 1) with as many dims as t, so it broadcasts along
        # dim 0. reshape is used so non-contiguous coeffs are accepted too.
        return coeffs.reshape(-1, *([1] * (t.dim() - 1)))

    def mult_coeffs(self,coeffs:torch.Tensor):
        """
        Multiply every tensor in place by ``coeffs`` along dim 0.

        Args:
            coeffs: tensor whose ``size(0)`` equals ``num()``.

        Raises:
            ValueError: if ``coeffs.size(0)`` differs from ``num()``.
        """
        if coeffs.size(0) != self.num():
            raise ValueError("coeffs的长度必须与num相同")

        for t in self.tensors:
            # In-place element-wise multiplication.
            t.mul_(self._expanded_coeffs(coeffs, t))

    def divide_coeffs(self,coeffs:torch.Tensor):
        """
        Divide every tensor in place by ``coeffs`` along dim 0.

        Args:
            coeffs: tensor whose ``size(0)`` equals ``num()``.

        Raises:
            ValueError: if ``coeffs.size(0)`` differs from ``num()``.
        """
        if coeffs.size(0) != self.num():
            raise ValueError("coeffs的长度必须与num相同")

        for t in self.tensors:
            # In-place element-wise division.
            t.div_(self._expanded_coeffs(coeffs, t))

    def flatten(self):
        """
        Return ``util.flatten(self.tensors, 1)``.

        NOTE(review): ``util`` is not imported by name in the visible part of
        this file; presumably it comes from ``from .comlib import *`` -- confirm.
        """
        return util.flatten(self.tensors,1)
    

    # def jacobi_tuple_to_vec(jacobi_tuple,first_dim_len):
    #     vec = []
    #     for t in jacobi_tuple:
    #         if first_dim_len!=0:
    #             vec.append(t.view(first_dim_len,-1))
    #         else:
    #             vec.append(t.view(-1))
    #     return torch.cat(vec,dim=1)