import torch
from .projector import CudaProjector, ProjectionType, BasicProjector
from .grad_calculator import count_parameters, grad_calculator, out_to_loss_grad_calculator,grad_calculator_list
from tqdm import tqdm
import time

class DropTRAK:
    '''
    TRAK-style data attribution averaged over an ensemble of model checkpoints.

    For every checkpoint the projected train gradients X and projected test
    gradients x are turned into the kernel ``x (X^T X)^{-1} X^T``; together with
    the out-to-loss gradient term Q, both are averaged over checkpoints and the
    final score is ``mean_kernel @ mean_Q``.
    '''

    def __init__(self,
                 model,
                 model_checkpoints,
                 train_loader,
                 test_loader,
                 model_output_class,
                 device,
                 independent_num,
                 dropout=False,
                 two_stage=False,
                 two_stage_rho=0,
                 train_sample_number=None,
                 test_sample_number=None,
                 LoRA_finetune=False,
                 LoRA_grad_only=True):
        '''
        :param model: a nn.Module instance, no need to load any checkpoint
        :param model_checkpoints: a list of checkpoint paths; its length is the number of ensemble models.
            With two_stage=True every item must itself be a list.
        :param train_loader: train samples in a data loader
        :param test_loader: test samples in a data loader
        :param model_output_class: a class definition inheriting BaseModelOutputClass
        :param device: the device to run on
        :param independent_num: the number of independent ensembles; q_drop_score splits
            model_checkpoints into this many equally sized consecutive groups
        :param dropout: whether dropout is enabled (via model.enable_dropout()) while scoring
        :param two_stage: whether two-stage regression with parameter two_stage_rho is used
        :param two_stage_rho: the off-diagonal value within each block of the block-diagonal matrix
        :param train_sample_number: the number of train samples (required when two_stage=True)
        :param test_sample_number: the number of test samples
        :param LoRA_finetune: whether the models were finetuned using LoRA
        :param LoRA_grad_only: whether only the LoRA parameters are used to compute TRAK
        :raises ValueError: if two_stage=True but train_sample_number is not given
        '''
        self.model = model
        self.model_checkpoints = model_checkpoints
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model_output_class = model_output_class
        self.device = device
        self.dropout = dropout
        self.two_stage = two_stage
        self.two_stage_rho = two_stage_rho
        self.train_sample_number = train_sample_number
        self.test_sample_number = test_sample_number

        self.LoRA_finetune = LoRA_finetune
        self.LoRA_grad_only = LoRA_grad_only
        self.independent_num = independent_num

        # model_checkpoints format checking
        if self.two_stage:
            assert isinstance(self.model_checkpoints[0], list), "items in model_checkpoints should be lists if you state two_stage=True"
            if self.train_sample_number is None:
                raise ValueError("train_sample_number must be given when two_stage=True")
            # V: ones on the diagonal, two_stage_rho everywhere else
            self.two_stage_v = torch.eye(len(self.model_checkpoints[0]))
            self.two_stage_v[self.two_stage_v == 0] = self.two_stage_rho
            print("two stage V = ", self.two_stage_v)
            # Invert the small block once and tile it, instead of re-inverting it per sample.
            inv_v = torch.linalg.inv(self.two_stage_v)
            self.two_stage_v_1 = torch.block_diag(*([inv_v] * self.train_sample_number)).to(self.device)
        else:
            assert isinstance(self.model_checkpoints[0], str), "items in model_checkpoints should be str if you state two_stage=False"

    # ------------------------------------------------------------------ helpers

    def _load_checkpoint(self, checkpoint_file):
        '''Load a checkpoint into self.model, switch to eval mode and re-enable dropout if requested.'''
        self.model.load_state_dict(torch.load(checkpoint_file))
        self.model.eval()
        if self.dropout:
            self.model.enable_dropout()

    def _select_parameters(self):
        '''
        Choose the parameters whose gradients TRAK uses.

        :return: (parameters, normalize_factor) where normalize_factor is the square
            root of the number of scalars whose gradients are taken.
        Side effect: with LoRA_finetune=True and LoRA_grad_only=False every model
        parameter gets requires_grad=True.
        '''
        if not self.LoRA_finetune:
            parameters = list(self.model.parameters())
            grad_dim = count_parameters(self.model)
        else:
            parameters = []
            for name, param in self.model.named_parameters():
                if self.LoRA_grad_only:
                    # only consider the LoRA params' grads
                    if "lora_" in name:
                        parameters.append(param)
                else:
                    # make the original params' grads computable as well
                    param.requires_grad = True
                    parameters.append(param)
            # normalize by the size of the gradient actually computed
            grad_dim = sum(p.numel() for p in parameters)
        normalize_factor = torch.sqrt(torch.tensor(grad_dim, dtype=torch.float32))
        return parameters, normalize_factor

    def _make_projector(self):
        '''Build the Rademacher random projector applied to the per-sample gradients.'''
        return CudaProjector(grad_dim=count_parameters(self.model), proj_dim=2048, seed=0,
                             proj_type=ProjectionType.rademacher, device="cuda", max_batch_size=8)

    @staticmethod
    def _kernel_projection(test_grads_p, train_grads_p):
        '''Return x (X^T X)^{-1} X^T for projected test grads x and projected train grads X.'''
        return test_grads_p @ torch.linalg.inv(train_grads_p.T @ train_grads_p) @ train_grads_p.T

    @staticmethod
    def _update_running_mean(running_avg, count, new_value):
        '''
        Fold new_value into a running mean that currently averages `count` items.

        The new item is detached and moved to CPU so that only one accumulator
        lives in memory instead of one tensor per checkpoint.
        '''
        return (running_avg * count + new_value.detach().cpu()) / (count + 1)

    def _ensemble_size(self):
        '''
        Number of consecutive checkpoints per independent ensemble.

        :raises ValueError: if independent_num is not positive or does not divide
            the number of checkpoints evenly.
        '''
        n_checkpoints = len(self.model_checkpoints)
        if self.independent_num <= 0 or n_checkpoints % self.independent_num != 0:
            raise ValueError(
                f"len(model_checkpoints)={n_checkpoints} must be a positive multiple of "
                f"independent_num={self.independent_num}")
        # integer division: the result is used as a list/tensor index and as an arange bound
        return n_checkpoints // self.independent_num

    # ------------------------------------------------------------------ scoring

    def _score(self):
        '''Compute the score with fresh projected gradients and Q for every checkpoint.'''
        running_avg_x_invXTX_XT = 0  # starts as a scalar, becomes a tensor via broadcast
        running_avg_Q = 0
        n_seen = 0
        before_checkpoint = time.time()
        for checkpoint_id, checkpoint_file in enumerate(self.model_checkpoints):
            self._load_checkpoint(checkpoint_file)
            print((count_parameters(self.model)))

            # customize the params considered by TRAK
            parameters, normalize_factor = self._select_parameters()
            # projection of the grads
            projector = self._make_projector()

            enter_phi = time.time()
            # Go through the training loader to get grads
            all_grads_p = grad_calculator(data_loader=self.train_loader, model=self.model, parameters=parameters,
                                          func=self.model_output_class.model_output, normalize_factor=normalize_factor,
                                          device=self.device, projector=projector, checkpoint_id=checkpoint_id)
            leave_phi = time.time()
            print("compute train phi grad time: ", leave_phi - enter_phi)
            Q = out_to_loss_grad_calculator(data_loader=self.train_loader, model=self.model,
                                            func=self.model_output_class.get_out_to_loss_grad)
            leave_out_grad = time.time()
            print("compute Q time: ", leave_out_grad - leave_phi)
            all_grads_test_p = grad_calculator(data_loader=self.test_loader, model=self.model, parameters=parameters,
                                               func=self.model_output_class.model_output, normalize_factor=normalize_factor,
                                               device=self.device, projector=projector, checkpoint_id=checkpoint_id)

            x_invXTX_XT = self._kernel_projection(all_grads_test_p, all_grads_p)

            # Use running avg to reduce mem usage
            running_avg_x_invXTX_XT = self._update_running_mean(running_avg_x_invXTX_XT, n_seen, x_invXTX_XT)
            running_avg_Q = self._update_running_mean(running_avg_Q, n_seen, Q)
            n_seen += 1

        score = running_avg_x_invXTX_XT @ running_avg_Q
        final_out = time.time()
        print("wall clock time: ", final_out - before_checkpoint)
        return score

    def _q_drop_score(self):
        '''
        Compute the score where the projected gradients are computed once per
        independent ensemble (a group of consecutive checkpoints, via
        grad_calculator_list) while Q is recomputed for every checkpoint.
        '''
        running_avg_x_invXTX_XT = 0  # starts as a scalar, becomes a tensor via broadcast
        running_avg_Q = 0
        n_seen = 0
        before_checkpoint = time.time()

        ensemble_num = self._ensemble_size()
        all_grads_p_multi_ckpt = None
        all_grads_test_p_multi_ckpt = None

        for checkpoint_id, checkpoint_file in enumerate(self.model_checkpoints):
            self._load_checkpoint(checkpoint_file)
            parameters, normalize_factor = self._select_parameters()

            # which of the group's projected gradients belongs to this checkpoint
            position = checkpoint_id % ensemble_num

            # first checkpoint of a new independent ensemble: compute the grads once
            # and project them once per checkpoint of the group
            if position == 0:
                projector = self._make_projector()
                checkpoint_id_list = torch.arange(checkpoint_id, checkpoint_id + ensemble_num)

                # compute many projected train phi
                all_grads_p_multi_ckpt = grad_calculator_list(data_loader=self.train_loader, model=self.model, parameters=parameters,
                                                              func=self.model_output_class.model_output, normalize_factor=normalize_factor,
                                                              device=self.device, projector=projector, checkpoint_id_list=checkpoint_id_list)
                # compute many projected test phi
                all_grads_test_p_multi_ckpt = grad_calculator_list(data_loader=self.test_loader, model=self.model, parameters=parameters,
                                                                   func=self.model_output_class.model_output, normalize_factor=normalize_factor,
                                                                   device=self.device, projector=projector, checkpoint_id_list=checkpoint_id_list)

            # Q is dynamic (depends on test-time dropout), so recompute it every time
            start_q = time.time()
            Q = out_to_loss_grad_calculator(data_loader=self.train_loader, model=self.model,
                                            func=self.model_output_class.get_out_to_loss_grad)
            end_q = time.time()
            print("Q time: ", end_q-start_q)

            x_invXTX_XT = self._kernel_projection(all_grads_test_p_multi_ckpt[position],
                                                  all_grads_p_multi_ckpt[position])

            # Use running avg to reduce mem usage
            running_avg_x_invXTX_XT = self._update_running_mean(running_avg_x_invXTX_XT, n_seen, x_invXTX_XT)
            running_avg_Q = self._update_running_mean(running_avg_Q, n_seen, Q)
            n_seen += 1

        score = running_avg_x_invXTX_XT @ running_avg_Q
        final_out = time.time()
        print("wall clock time: ", final_out - before_checkpoint)

        return score

    def score(self):
        '''
        :return: a tensor with shape (number of test data, number of train data)
        '''
        return self._score()

    def q_drop_score(self):
        '''
        Like score(), but reuses projected gradients within each independent
        ensemble and only recomputes Q per checkpoint.
        '''
        return self._q_drop_score()
    
    
