from .comlib import *
from . import compute_grad
from .tensor_tuple import tensorTupleToDevice,ModelPara,TensorTuple
from .. import util,task

def state_dict_to_form_dict(state_dict):
    """Map every entry name in ``state_dict`` to the ``size()`` of its value.

    The returned dict keeps the iteration order of ``state_dict``.
    """
    return {name: state_dict[name].size() for name in state_dict}

def model_to_form_dict(model: torch.nn.Module):
    """Return the name -> size mapping for every entry of ``model.state_dict()``."""
    current_state = model.state_dict()
    return state_dict_to_form_dict(current_state)

class ModelSetup():
    """Bundle a ``torch.nn.Module`` with its device and cached parameter metadata.

    Attributes maintained by this class:
        device: device the model was last moved to.
        trans_cpu: flag forwarded to ``tensorTupleToDevice`` for batch tensors.
        model: the wrapped module.
        form_dict: ``state_dict`` entry name -> size.
        param_name_tuple: names yielded by ``model.named_parameters()``.
        parameter_tuple: parameters yielded by ``model.parameters()``.
        ini_weight: object whose ``initialize_weights`` is applied to the model.
    """

    def __init__(self, device, model, ini_weight: "task.Initialize", trans_cpu=True, initialize=True):
        '''
        example: ModelSetup(device, mnist.SimpleNN(), ini_weight)

        device: device the model is moved to.
        model: the torch.nn.Module to wrap.
        ini_weight: object whose ``initialize_weights`` is passed to ``model.apply``.
        trans_cpu: forwarded to ``tensorTupleToDevice`` when batches are moved.
        initialize: when True, ``initialize_model`` is run right away.
        '''
        self.device = device
        self.trans_cpu = trans_cpu
        self.ini_weight = ini_weight
        self.model = model.to(device)
        self._refresh_model_metadata()

        if initialize:
            self.initialize_model()

    def _refresh_model_metadata(self):
        """Recompute every cached attribute that is derived from ``self.model``."""
        self.form_dict = model_to_form_dict(self.model)
        self.param_name_tuple = tuple(name for name, _ in self.model.named_parameters())
        self.parameter_tuple = tuple(self.model.parameters())

    def to(self, device, trans_cpu=True):
        """Move the model to ``device`` and record it as the current device.

        When ``trans_cpu`` is true the model is moved to the CPU first and only
        then to ``device``.
        """
        if trans_cpu:
            self.model = self.model.cpu()
        self.model = self.model.to(device)
        # Re-read the parameters so the cached tuple points at the moved tensors.
        self.parameter_tuple = tuple(self.model.parameters())
        self.device = device

    def get_para_num(self):
        """Return the total element count over all entries of ``form_dict``.

        ``form_dict`` is built from ``state_dict()``, so every entry stored there
        is counted. The result is a 0-dim ``LongTensor``, or the int ``0`` when
        ``form_dict`` is empty.
        """
        num = 0
        for entry_size in self.form_dict.values():
            num = num + torch.prod(torch.LongTensor(list(entry_size)))
        return num

    def set_model(self, new_model):
        """Replace the wrapped model and refresh all metadata derived from it.

        The parameter names and ``form_dict`` are recomputed together with the
        parameter tuple; otherwise a model with a different layout would be
        paired with the previous model's names in ``calcuJacobian`` and
        ``calcuLoss``. The new model is used as given (it is not moved to
        ``self.device``).
        """
        self.model = new_model
        self._refresh_model_metadata()

    def initialize_model(self):
        """Apply ``ini_weight.initialize_weights`` to the model via ``Module.apply``."""
        self.model.apply(self.ini_weight.initialize_weights)
        self.parameter_tuple = tuple(self.model.parameters())

    def calcuJacobian(self, input_args, criterion, output_args):
        """Differentiate ``criterion(model(*input_args), *output_args)`` w.r.t. the parameters.

        The model is switched to train mode and both argument tuples are passed
        through ``tensorTupleToDevice`` first. The result is whatever
        ``compute_grad.compute_jacobian`` returns for the loss function evaluated
        at ``self.parameter_tuple``.
        """
        self.model.train()
        input_args = tensorTupleToDevice(input_args, self.device, self.trans_cpu)
        output_args = tensorTupleToDevice(output_args, self.device, self.trans_cpu)

        def func(*parameter_tuple):
            # Run the model with the supplied tensors substituted for its own
            # parameters so the loss is an explicit function of them.
            parameter_dict = dict(zip(self.param_name_tuple, parameter_tuple))
            out = torch.func.functional_call(self.model, parameter_dict, input_args)
            return criterion(out, *output_args)

        return compute_grad.compute_jacobian(func, self.parameter_tuple)

    def calcuDatasetJacobian(self, funcOut, dataset, batch_size) -> "TensorTuple":
        """Accumulate ``calcuJacobian`` over ``dataset`` in unshuffled batches.

        ``funcOut`` is used as the ``criterion`` argument of ``calcuJacobian``.
        Each batch result is wrapped in ``TensorTuple`` and combined by
        ``util.getDataloaderAvg`` in ``'sum'`` mode; its ``get_sum()`` is returned.
        """
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        funcGetJac = util.composefunc(TensorTuple, self.calcuJacobian)

        def batch_jacobian(input_args, output_args):
            return funcGetJac(input_args, funcOut, output_args)

        jacobi_tuple = util.getDataloaderAvg(dataloader, batch_jacobian, 'sum')
        return jacobi_tuple.get_sum()

    def calcuLoss(self, input_args, criterion, output_args):
        '''
        Evaluate ``criterion(model(*input_args), *output_args)`` in eval mode
        without tracking gradients.

        input_args:data
        output_args:target
        '''
        self.model.eval()
        input_args = tensorTupleToDevice(input_args, self.device, self.trans_cpu)
        output_args = tensorTupleToDevice(output_args, self.device, self.trans_cpu)

        with torch.no_grad():
            parameter_dict = dict(zip(self.param_name_tuple, self.parameter_tuple))
            out = torch.func.functional_call(self.model, parameter_dict, input_args)
            loss = criterion(out, *output_args)
        return loss

    def getDatasetCriterion(self, dataset, criterion, batch_size,
              reduction: Literal['mean', 'sum'] = 'mean'):
        '''
        Evaluate ``criterion`` over ``dataset`` in unshuffled batches.

        criterion reduction should be mean: each batch value is fed to
        ``util.MovingAvg.update`` together with the batch size
        (``len(input_args[0])``).

        Returns ``val.mean`` for ``reduction='mean'`` and ``val.get_sum()`` for
        ``reduction='sum'``.

        Raises ValueError for any other ``reduction``; this is checked before
        the dataset is touched so a typo does not cost a full pass.
        '''
        if reduction not in ('mean', 'sum'):
            raise ValueError(
                f"reduction must be 'mean' or 'sum', got {reduction!r}")

        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        val = util.MovingAvg()
        for input_args, output_args in dataloader:
            batch_value = self.calcuLoss(input_args, criterion, output_args)
            val.update(batch_value, len(input_args[0]))

        if reduction == 'mean':
            return val.mean
        return val.get_sum()

    def gettWeightNorm(self):
        """Return ``ModelPara(self.parameter_tuple).norm()``."""
        return ModelPara(self.parameter_tuple).norm()

class Argset(Dataset):
    """Wrap a dataset of ``(data, target)`` pairs so each item becomes
    ``((data,), (target,))``, i.e. a tuple of input arguments and a tuple of
    output arguments.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # The wrapped item must unpack into exactly two values.
        sample, label = self.dataset[item]
        return (sample,), (label,)
    
# class ModelLoss():
#     def __init__(self,modelSetup: ModelSetup,criterion: function):
#         '''
#         criterion: nn.CrossEntropyLoss() (out,target)
#         '''
#         self.modelSetup=modelSetup
#         self.device=modelSetup.device
#         self.criterion=criterion
    
            

    

