import torch
from .projector import NoOpProjector, ProjectionType, CudaProjector
from .grad_calculator import count_parameters, grad_calculator
from tqdm import tqdm


class DropGradDot:
    '''
    Score train/test sample pairs by the similarity of their model gradients.

    For each checkpoint in ``model_checkpoints`` the model is reloaded, the
    gradients of ``model_output_class.model_output`` are computed on the train
    and test loaders via ``grad_calculator``, and a (train x test) similarity
    matrix is formed (dot product, or cosine similarity when ``cos=True``).
    The final score is the mean of these matrices over all checkpoints.
    '''

    def __init__(self,
                 model,
                 model_checkpoints,
                 train_loader,
                 test_loader,
                 model_output_class,
                 device,
                 dropout=False,
                 cos=False):
        '''
        :param model: a nn.Module instance; no checkpoint needs to be loaded beforehand
        :param model_checkpoints: an iterable of checkpoint paths; its length is the
            number of ensemble members
        :param train_loader: train samples in a data loader
        :param test_loader: test samples in a data loader
        :param model_output_class: a class definition inheriting BaseModelOutputClass
        :param device: the device gradients are computed on
        :param dropout: if True, ``model.enable_dropout()`` is called after
            ``model.eval()`` (presumably re-activating dropout layers; the model
            must provide this method)
        :param cos: if True, use cosine similarity between gradient vectors
            instead of the raw dot product
        '''
        self.model = model
        self.model_checkpoints = model_checkpoints
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model_output_class = model_output_class
        self.device = device
        self.dropout = dropout
        self.cos = cos

    @staticmethod
    def _similarity(train_grads, test_grads, cos):
        '''
        Similarity matrix between two sets of gradient vectors (one per row).

        :param train_grads: tensor of shape (number of train data, d)
        :param test_grads: tensor of shape (number of test data, d)
        :param cos: if True, cosine similarity; otherwise raw dot product
        :return: tensor of shape (number of train data, number of test data)
        '''
        if cos:
            # Normalize every gradient row on its own. Dividing by the norm of the
            # whole matrix would only rescale by one constant, not yield cosines.
            # normalize() clamps the norm with a small eps, so all-zero rows map
            # to 0 instead of NaN.
            train_grads = torch.nn.functional.normalize(train_grads, p=2, dim=1)
            test_grads = torch.nn.functional.normalize(test_grads, p=2, dim=1)
        return train_grads @ test_grads.T

    def _checkpoint_similarity(self, checkpoint_id, checkpoint_file):
        '''
        Load one checkpoint and return its (train x test) gradient similarity.

        :param checkpoint_id: index of the checkpoint, forwarded to grad_calculator
        :param checkpoint_file: path of the checkpoint to load
        :return: detached CPU tensor of shape (number of train data, number of test data)
        '''
        # map_location="cpu" lets GPU-saved checkpoints load on any host;
        # load_state_dict then copies the weights onto the model's own device.
        self.model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
        self.model.eval()
        if self.dropout:
            self.model.enable_dropout()
        num_parameters = count_parameters(self.model)
        print(self.model)
        print(num_parameters)
        parameters = list(self.model.parameters())

        # Projection of the grads. NoOpProjector presumably passes gradients
        # through unprojected. A CudaProjector (proj_dim=512*8,
        # ProjectionType.normal) was used here previously as an alternative.
        projector = NoOpProjector(grad_dim=num_parameters, proj_dim=2048, seed=0,
                                  proj_type=ProjectionType.rademacher, device=self.device,
                                  max_batch_size=8)

        grad_kwargs = dict(model=self.model, parameters=parameters,
                           func=self.model_output_class.model_output, normalize_factor=1,
                           device=self.device, projector=projector, checkpoint_id=checkpoint_id)
        train_grads = grad_calculator(data_loader=self.train_loader, **grad_kwargs)
        test_grads = grad_calculator(data_loader=self.test_loader, **grad_kwargs)

        return self._similarity(train_grads, test_grads, self.cos).detach().cpu()

    def score(self):
        '''
        :return: a tensor with shape (number of test data, number of train data)
            holding the train/test gradient similarity averaged over checkpoints
        :raises ValueError: if ``model_checkpoints`` is empty
        '''
        checkpoints = list(self.model_checkpoints)
        if not checkpoints:
            raise ValueError("model_checkpoints must contain at least one checkpoint")

        # Keep a running sum rather than stacking every per-checkpoint matrix,
        # so memory stays at one (train x test) matrix regardless of ensemble size.
        score_sum = None
        for checkpoint_id, checkpoint_file in enumerate(tqdm(checkpoints)):
            similarity = self._checkpoint_similarity(checkpoint_id, checkpoint_file)
            if score_sum is None:
                # `similarity` is a fresh tensor, so accumulating into it is safe.
                score_sum = similarity
            else:
                score_sum += similarity

        # mean over all checkpoints
        score = score_sum / len(checkpoints)
        return score.T

