import torch
from .projector import CudaProjector, ProjectionType, BasicProjector, NoOpProjector
from .grad_calculator import count_parameters, grad_calculator, out_to_loss_grad_calculator, count_parameters_list
from tqdm import tqdm


class DropTRAK:
    '''
    Attribution scorer that combines projected per-sample gradients of the train
    and test sets over an ensemble of checkpoints.

    Three scoring modes are available and selected by the constructor flags
    (see :meth:`score`): plain ensemble scoring, two-stage regression scoring,
    and a single-checkpoint mode where several passes over one checkpoint play
    the role of the ensemble (``dropout_only_Q``).
    '''

    def __init__(self,
                 model,
                 model_checkpoints,
                 train_loader,
                 test_loader,
                 model_output_class,
                 device,
                 dropout=False,
                 dropout_only_Q=False,
                 two_stage=False,
                 two_stage_rho=0,
                 train_sample_number=None,
                 test_sample_number=None):
        '''
        :param model: a nn.module instance, no need to load any checkpoint
        :param model_checkpoints: a list of checkpoint paths; the length of this list is the
            number of ensemble members. With two_stage=True or dropout_only_Q=True every item
            must itself be a list of checkpoint paths.
        :param train_loader: train samples in a data loader
        :param test_loader: test samples in a data loader
        :param model_output_class: a class definition inheriting BaseModelOutputClass
        :param device: the device running
        :param dropout: if dropout is enabled (calls ``model.enable_dropout()`` after loading)
        :param dropout_only_Q: if True (and two_stage is False), only the first checkpoint of
            every inner list is loaded and scored once per entry of that list
        :param two_stage: if we use two stage regression with parameter two_stage_rho
        :param two_stage_rho: the off-diagonal item value within the block diagonal matrix
        :param train_sample_number: the number of train samples (required when two_stage=True)
        :param test_sample_number: the number of test samples (required when two_stage=True)
        :raises ValueError: if two_stage=True and model_checkpoints is empty or
            train_sample_number is None
        '''
        self.model = model
        self.model_checkpoints = model_checkpoints
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model_output_class = model_output_class
        self.device = device
        self.dropout = dropout
        self.dropout_only_Q = dropout_only_Q
        self.two_stage = two_stage
        self.two_stage_rho = two_stage_rho
        self.train_sample_number = train_sample_number
        self.test_sample_number = test_sample_number

        # model_checkpoints format checking. No check is made for the non two-stage
        # modes because dropout_only_Q also expects nested lists.
        if self.two_stage:
            if len(self.model_checkpoints) == 0:
                raise ValueError("model_checkpoints must not be empty if you state two_stage=True")
            # Kept as an assert so the exception type seen by existing callers is unchanged.
            assert isinstance(self.model_checkpoints[0], list), "items in model_checkpoints should be lists if you state two_stage=True"
            if self.train_sample_number is None:
                raise ValueError("train_sample_number is required if you state two_stage=True")

            # V has ones on the diagonal and two_stage_rho everywhere else; its size is the
            # number of checkpoints inside one training run.
            self.two_stage_v = torch.eye(len(self.model_checkpoints[0]))
            self.two_stage_v[self.two_stage_v == 0] = self.two_stage_rho
            print("two stage V = ", self.two_stage_v)

            # V^-1 is identical for every train sample, so invert once and repeat the
            # block train_sample_number times along the diagonal.
            v_inv = torch.linalg.inv(self.two_stage_v)
            self.two_stage_v_1 = torch.block_diag(*([v_inv] * self.train_sample_number)).to(self.device)

    @staticmethod
    def _update_running_mean(mean, count, value):
        '''Fold ``value`` (moved to CPU, detached) into a mean over ``count`` items; return (new_mean, count + 1).'''
        value = value.detach().cpu()
        return (mean * count + value) / (count + 1), count + 1

    def _require_checkpoints(self):
        '''Raise ValueError when there is no checkpoint to score with.'''
        if len(self.model_checkpoints) == 0:
            raise ValueError("model_checkpoints is empty; at least one checkpoint is required to compute scores")

    def _prepare_checkpoint(self, checkpoint_file):
        '''Load ``checkpoint_file`` into the model and return (parameters, normalize_factor, projector).'''
        # Load to CPU first: load_state_dict copies the values into the model's own
        # parameters, so this works regardless of the device the checkpoint was saved from.
        self.model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
        self.model.eval()
        if self.dropout:
            # eval() was just called; the model's own hook is used to switch dropout back on.
            self.model.enable_dropout()
        print(self.model)
        num_params = count_parameters(self.model)
        print(num_params)
        parameters = list(self.model.parameters())
        normalize_factor = torch.sqrt(torch.tensor(num_params, dtype=torch.float32))

        # projection of the grads
        # NOTE(review): the projector device is hard-coded to "cuda" rather than
        # self.device -- confirm whether that is intended.
        projector = CudaProjector(grad_dim=num_params, proj_dim=2048, seed=0,
                                  proj_type=ProjectionType.rademacher, device="cuda", max_batch_size=8)
        return parameters, normalize_factor, projector

    def _projected_grads(self, data_loader, parameters, normalize_factor, projector, checkpoint_id, **extra):
        '''Run grad_calculator on ``data_loader`` with the settings shared by all scoring modes.'''
        return grad_calculator(data_loader=data_loader, model=self.model, parameters=parameters,
                               func=self.model_output_class.model_output, normalize_factor=normalize_factor,
                               device=self.device, projector=projector, checkpoint_id=checkpoint_id, **extra)

    def _out_to_loss_grads(self):
        '''Run out_to_loss_grad_calculator on the train loader.'''
        return out_to_loss_grad_calculator(data_loader=self.train_loader, model=self.model,
                                           func=self.model_output_class.get_out_to_loss_grad)

    def _score_two_stage(self):
        '''
        Two-stage regression scoring.

        For every training run (outer list of model_checkpoints) the gradients of all its
        checkpoints are stacked, regrouped per train sample, and combined as
        x (X^T V^-1 X)^-1 X^T V^-1; the per-run results are averaged and multiplied with
        the averaged out-to-loss gradients.

        :return: the score tensor (on CPU)
        :raises ValueError: if model_checkpoints is empty or test_sample_number is None
        '''
        self._require_checkpoints()
        if self.test_sample_number is None:
            raise ValueError("test_sample_number is required if you state two_stage=True")
        n_train = self.train_sample_number
        n_test = self.test_sample_number

        mean_proj = 0  # using broadcast
        mean_Q = 0  # using broadcast
        count = 0

        # outer loop, for different trained model's ensemble
        for run_id, checkpoint_file_list in enumerate(tqdm(self.model_checkpoints)):
            train_grads = []
            out_to_loss = []
            test_grads = []
            # inner loop for checkpoints within a training process
            for checkpoint_file in tqdm(checkpoint_file_list):
                parameters, normalize_factor, projector = self._prepare_checkpoint(checkpoint_file)

                # The run index (not the inner checkpoint index) is passed as checkpoint_id,
                # as in the original implementation.
                train_grads.append(self._projected_grads(self.train_loader, parameters, normalize_factor,
                                                         projector, run_id))  # [(N, K), ...]
                out_to_loss.append(self._out_to_loss_grads())  # [(N, N), ...]
                test_grads.append(self._projected_grads(self.test_loader, parameters, normalize_factor,
                                                        projector, run_id))  # [(N_te, K), ...]

            X = torch.cat(train_grads, dim=0)  # (MN, K)
            Q = torch.cat(out_to_loss, dim=0)  # (MN, N)
            x = torch.cat(test_grads, dim=0)  # (MN_te, K)

            # reorder: rows are stacked checkpoint-major; regroup them so that all rows of
            # one train sample are adjacent, matching the block-diagonal layout of V^-1
            reorder_index = [idx for train_i in range(n_train) for idx in range(train_i, X.shape[0], n_train)]
            X = X[reorder_index]
            Q = Q[reorder_index]

            proj = x @ torch.linalg.inv(X.T @ self.two_stage_v_1 @ X) @ X.T @ self.two_stage_v_1  # (MN_te, MN)
            # average the consecutive blocks of n_test rows (one block per checkpoint)
            proj = torch.mean(
                torch.stack([proj[start:start + n_test, :] for start in range(0, proj.shape[0], n_test)], dim=0),
                dim=0
            )

            # Use running avg to reduce mem usage
            mean_proj, _ = self._update_running_mean(mean_proj, count, proj)
            mean_Q, count = self._update_running_mean(mean_Q, count, Q)

        return mean_proj @ mean_Q

    def _score(self):
        '''
        Plain ensemble scoring: one checkpoint per ensemble member.

        For every checkpoint x (X^T X)^-1 X^T is computed from the projected test (x) and
        train (X) gradients; these products and the out-to-loss gradients Q are averaged
        separately over the checkpoints and multiplied at the end. (An alternative would
        average x, X^T X and X^T Q separately before combining them; that is not what
        this method does.)

        :return: the score tensor (on CPU)
        :raises ValueError: if model_checkpoints is empty
        '''
        self._require_checkpoints()

        mean_proj = 0  # using broadcast
        mean_Q = 0  # using broadcast
        count = 0

        for checkpoint_id, checkpoint_file in enumerate(tqdm(self.model_checkpoints)):
            parameters, normalize_factor, projector = self._prepare_checkpoint(checkpoint_file)

            # Go through the training loader to get grads
            all_grads_p = self._projected_grads(self.train_loader, parameters, normalize_factor,
                                                projector, checkpoint_id)
            out_to_loss_grads = self._out_to_loss_grads()
            all_grads_test_p = self._projected_grads(self.test_loader, parameters, normalize_factor,
                                                     projector, checkpoint_id)

            x_invXTX_XT = all_grads_test_p @ torch.linalg.inv(all_grads_p.T @ all_grads_p) @ all_grads_p.T

            # Use running avg to reduce mem usage
            mean_proj, _ = self._update_running_mean(mean_proj, count, x_invXTX_XT)
            mean_Q, count = self._update_running_mean(mean_Q, count, out_to_loss_grads)

        return mean_proj @ mean_Q

    def _score_only_Q(self):
        '''
        Scoring where only the first checkpoint of every inner list is loaded.

        The length of the inner list only decides how many ensemble entries are produced:
        the gradient calculators are asked for that many results (a list of checkpoint ids
        is passed together with ``disable_dropout=True``) and the out-to-loss gradients are
        recomputed that many times.

        :return: the score tensor (on CPU)
        :raises ValueError: if model_checkpoints is empty
        :raises TypeError: if an item of model_checkpoints is a plain string
        '''
        self._require_checkpoints()

        mean_proj = 0  # using broadcast
        mean_Q = 0  # using broadcast
        count = 0

        for run_id, checkpoint_file_list in enumerate(tqdm(self.model_checkpoints)):
            if isinstance(checkpoint_file_list, str):
                # Indexing a str with [0] would silently pick its first character as the path.
                raise TypeError("items in model_checkpoints should be lists if you state dropout_only_Q=True")
            num_members = len(checkpoint_file_list)
            parameters, normalize_factor, projector = self._prepare_checkpoint(checkpoint_file_list[0])

            # One id per ensemble entry of this run.
            # NOTE(review): the factor 100 presumably keeps ids of different runs apart,
            # which would assume fewer than 100 entries per run -- confirm.
            member_ids = [run_id * 100 + i for i in range(num_members)]

            # Go through the training loader to get grads; each call gets its own copy of
            # the id list, as in the original implementation.
            all_grads_p = self._projected_grads(self.train_loader, parameters, normalize_factor,
                                                projector, list(member_ids), disable_dropout=True)
            out_to_loss_grads = [self._out_to_loss_grads() for _ in range(num_members)]
            all_grads_test_p = self._projected_grads(self.test_loader, parameters, normalize_factor,
                                                     projector, list(member_ids), disable_dropout=True)

            for i in range(num_members):
                x_invXTX_XT = all_grads_test_p[i] @ torch.linalg.inv(all_grads_p[i].T @ all_grads_p[i]) @ all_grads_p[i].T

                # Use running avg to reduce mem usage
                mean_proj, _ = self._update_running_mean(mean_proj, count, x_invXTX_XT)
                mean_Q, count = self._update_running_mean(mean_Q, count, out_to_loss_grads[i])

        return mean_proj @ mean_Q

    def score(self):
        '''
        Compute the attribution scores with the mode selected in the constructor;
        two_stage takes precedence over dropout_only_Q.

        :return: a tensor with shape (number of test data, number of train data)
        '''
        if self.two_stage:
            return self._score_two_stage()
        if self.dropout_only_Q:
            return self._score_only_Q()
        return self._score()
