import logging
import numpy as np
import torch
from torch import nn
from torch.serialization import load
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils.inc_net import Naming2Learn
from models.base import BaseLearner
from utils.toolkit import get_attribute
import matplotlib.pyplot as plt
import copy

'''
The incremental_train() function serves as the main function for performing incremental training.
The core function is train(), which handles the training process for each task.
'''


def generate_weights(num_samples, mean=1, std=0.5, device='cpu', norm=True):
    """Draw ``num_samples`` weights from a normal distribution.

    Args:
        num_samples: Number of weights to draw.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
        device: Device the returned tensor is moved to.
        norm: When true, rescale the weights so that they sum to
            ``num_samples`` (i.e. their average becomes 1).

    Returns:
        A 1-D tensor of length ``num_samples`` on ``device``.
    """
    # Sampling happens on the default device and is moved afterwards, so the
    # draw consumes the default random generator regardless of ``device``.
    weights = torch.normal(mean=mean, std=std, size=(num_samples,)).to(device)
    if norm:
        # Tensor.sum() reduces in a single op; the builtin sum() would walk the
        # tensor element by element and issue one addition per sample.
        weights = weights * len(weights) / weights.sum()
    return weights

class Learner(BaseLearner):
    def __init__(self, args):
        """Create the network and read the training settings from ``args``.

        Each setting listed below is looked up with ``get_attribute`` and
        falls back to the given default.
        """
        super().__init__(args)
        self.args = args

        self._train_transformer = False
        self._network = Naming2Learn(args)

        # (attribute name, fallback value) pairs, resolved in this order.
        setting_defaults = (
            ("batch_size", 48),
            ("num_workers", 8),
            ("init_lr", 0.01),
            ("weight_decay", 0.0005),
            ("min_lr", 1e-8),
            ("frozen_layers", None),
            ("tuned_epoch", 5),
        )
        for setting, fallback in setting_defaults:
            setattr(self, setting, get_attribute(args, setting, fallback))

        self._known_classes = 0
        self.prototype = []
        self.R = None


 
    def after_task(self):
        self._known_classes = self._total_classes

    def incremental_train(self, data_manager):
        """Advance to the next task, build its data loaders and run ``train``.

        The training set covers only the classes added by the new task; the
        test set covers every class from 0 up to the new total.

        Args:
            data_manager: Object providing ``get_task_size`` and ``get_dataset``.
        """
        self._cur_task += 1
        first_new_class = self._known_classes
        self._total_classes = first_new_class + data_manager.get_task_size(self._cur_task)

        logging.info("Learning on {}-{}".format(first_new_class, self._total_classes))

        new_class_ids = np.arange(first_new_class, self._total_classes)
        self.train_dataset = data_manager.get_dataset(
            new_class_ids, source="train", mode="train", aug=1
        )
        self.data_manager = data_manager
        self._network.to(self._device)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

        all_class_ids = np.arange(0, self._total_classes)
        eval_dataset = data_manager.get_dataset(all_class_ids, source="test", mode="test")
        self.test_loader = DataLoader(
            eval_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

        self._network.update_task()
        self.train(self.train_loader)
        
    def backup_matrix(self):
        self.auto_cor_backup = copy.deepcopy(self.auto_cor) 
        self.crs_cor_backup = copy.deepcopy(self.crs_cor)

    def reset_matrix(self):
        self.auto_cor = copy.deepcopy(self.auto_cor_backup) 
        self.crs_cor = copy.deepcopy(self.crs_cor_backup)

    def feature_reduction(self,X,variance_threshold=0.75):


        X_centered = X - X.mean(dim=0)
        
        cov = (X_centered.T @ X_centered) / (X.shape[0] - 1)
        
        eigenvalues, eigenvectors = torch.linalg.eigh(cov)

        sorted_idx = torch.argsort(eigenvalues, descending=True)
        eigenvalues = eigenvalues[sorted_idx]
        eigenvectors = eigenvectors[:, sorted_idx]
        explained_variance_ratio = eigenvalues / eigenvalues.sum()
        cumulative_variance = torch.cumsum(explained_variance_ratio, dim=0)
        
        n_components = torch.sum(cumulative_variance < variance_threshold).item() + 1
        components = eigenvectors[:, :n_components]

        print(components.shape)
        return components.shape[1], components


        
    def init_adapter2 (self,dim):

        # Initialize the classifier for Pseudo label updating
        if self._cur_task == 0:

            self._network.analytic_adaptor2 = nn.Linear(dim, self._total_classes, bias=False).to(self._device)
            self.auto_cor2 = torch.zeros(dim, dim).to(self._device)
            self.crs_cor2 = torch.zeros(dim, self._total_classes).to(self._device)
        else :

            self._network.analytic_adaptor2 = nn.Linear(dim, self._total_classes-self._known_classes, bias=False).to(self._device)
            self.auto_cor2 = torch.zeros(dim, dim).to(self._device)
            self.crs_cor2 = torch.zeros(dim, self._total_classes-self._known_classes).to(self._device)
            
    def train(self, train_loader):
        """Fit the analytic classifiers of the current task from pseudo-labels.

        The loader is traversed ``args["cyc_epoch"]`` times.  In every cycle
        each batch is encoded, pseudo-labelled, re-weighted and accumulated
        into regression statistics.  Cycle 0 labels images by image/text
        similarity; later cycles label them with the ``analytic_adaptor2``
        fitted at the end of the previous cycle.  ``analytic_adaptor`` is only
        refitted from the pseudo-labels of the last cycle.

        Reads ``args["cyc_epoch"]``, ``args["reduction"]``, ``args["gau_std"]``
        and ``args["regularization"]``.

        Args:
            train_loader: Iterable yielding image batches; each item only needs
                to support ``.to(device)`` (no labels are read).
        """
        # hyperparameters
        cyc_epoch = self.args["cyc_epoch"]  # iterations of pseudo-labels updating
        reduction = self.args["reduction"]  # \theta, the threshold for feature dimensionality reduction
        # self.args["gau_std"]  # \sigma The standard deviation to sample intra-class weights

        self._network.to(self._device)
        self._network.eval()

        class_to_label = self.data_manager._class_to_label
        current_labels = class_to_label[self._known_classes:self._total_classes]

        # text feature
        templates = self.data_manager._data_to_prompt[0]
        texts_current = [templates.format(inst) for inst in current_labels]
        texts_current = self._network.tokenizer(texts_current).to(self._device)
        # The text features are used as constants only, so no autograd graph is kept.
        with torch.no_grad():
            text_features = self._network.model.encode_text(texts_current)
            text_features_raw = text_features / text_features.norm(dim=-1, keepdim=True)

        # self._network.analytic_adaptor : The incremental image classifier (\hat{W} in Eq 12)
        # self._network.analytic_adaptor2 : classifier for Pseudo label updating (\hat{W}' in Eq 5)
        self.init_adapter2(self.feature_dim)

        if self._cur_task == 0:
            self._network.analytic_adaptor = nn.Linear(
                self.feature_dim, self._total_classes, bias=False
            ).to(self._device)
            # Initialize matrix A and C
            self.auto_cor = torch.zeros(self.feature_dim, self.feature_dim).to(self._device)
            self.crs_cor = torch.zeros(self.feature_dim, self._total_classes).to(self._device)
        else:
            current_classes = self.data_manager.get_task_size(self._cur_task)
            # expand the dimension of crs_cor
            self.crs_cor = torch.cat(
                [self.crs_cor, torch.zeros(self.feature_dim, current_classes).to(self._device)],
                dim=1,
            )

        self.backup_matrix()

        for cyc in range(cyc_epoch):
            self.reset_matrix()

            print('----cyc epoch {}-----'.format(cyc))

            is_last_cycle = cyc == cyc_epoch - 1
            # Keep a handle on the classifier fitted in the previous cycle:
            # init_adapter2() below installs a fresh module on the network, but
            # every batch of this cycle must still be labelled by the old one.
            pseudo_labeller = self._network.analytic_adaptor2

            with torch.no_grad():
                for batch_idx, images in enumerate(train_loader):
                    images = images.to(self._device)

                    train_features = self._network.model.encode_image(images)

                    if reduction:
                        if cyc == 0 and batch_idx == 0:
                            # Fit the projection once so that all batches and
                            # all cycles share the same reduced feature space.
                            self.feature_dim_reduction, basis = self.feature_reduction(
                                train_features, reduction
                            )
                            self.reduction_layer = nn.Linear(
                                self.feature_dim, self.feature_dim_reduction, bias=False
                            ).to(self._device)
                            self.reduction_layer.weight = torch.nn.parameter.Parameter(basis.T)
                        # Always project the *current* batch; reusing the
                        # projection of an earlier batch would pair its
                        # pseudo-labels with the wrong images.
                        train_features_reduction = self.reduction_layer(
                            train_features - train_features.mean(dim=0)
                        )
                    else:
                        # No reduction requested: use the raw features as they are.
                        self.feature_dim_reduction = train_features.size(1)
                        train_features_reduction = train_features

                    # Pseudo Label
                    if cyc == 0:
                        image_features_raw = train_features / train_features.norm(dim=-1, keepdim=True)
                        logits_raw = 100 * (image_features_raw @ text_features_raw.T)
                    else:
                        logits_raw = pseudo_labeller(train_features_reduction)

                    if batch_idx == 0:
                        # Reset the pseudo-label classifier statistics once per
                        # cycle so they accumulate over every batch.
                        self.init_adapter2(self.feature_dim_reduction)

                    target = logits_raw.argmax(dim=1) + self._known_classes
                    m_matrix = self._sample_weights(logits_raw, target)

                    train_labels_one_hot = F.one_hot(target, self._total_classes).float()

                    # Regression with re-weighting (Eq.13).  Scaling the rows by
                    # the weights equals multiplying by diag(weights) without
                    # materialising a dense batch x batch matrix.
                    row_weights = m_matrix.unsqueeze(1)
                    if is_last_cycle:
                        # Only update the analytic_adaptor with the pseudo label
                        # of the last pseudo-labels updating epoch
                        weighted = row_weights * train_features
                        self.auto_cor += weighted.T @ train_features
                        self.crs_cor += weighted.T @ train_labels_one_hot

                    weighted_reduction = row_weights * train_features_reduction
                    self.auto_cor2 += weighted_reduction.T @ train_features_reduction
                    self.crs_cor2 += weighted_reduction.T @ (
                        train_labels_one_hot[:, self._known_classes:self._total_classes]
                    )

            # Eq.12
            self.R2, Delta2 = self._ridge_solution(self.auto_cor2, self.crs_cor2)
            self._network.analytic_adaptor2.weight = torch.nn.parameter.Parameter(torch.t(Delta2))

            # Only update the analytic_adaptor with the pseudo label of the last pseudo-labels updating epoch
            if is_last_cycle:
                self.R, Delta = self._ridge_solution(self.auto_cor, self.crs_cor)
                self._network.analytic_adaptor.weight = torch.nn.parameter.Parameter(torch.t(Delta))

    def _sample_weights(self, logits_raw, target):
        """Return one re-weighting factor per sample of a batch.

        Args:
            logits_raw: ``(batch, classes)`` logits the pseudo-labels came from.
            target: Pseudo-label of every sample (global class indices).

        Returns:
            1-D tensor with ``batch`` entries on ``self._device``.
        """
        # log_softmax keeps the entropy finite when a probability underflows
        # to 0; prob * torch.log(prob) would give 0 * -inf = NaN there.
        prob = F.softmax(logits_raw, dim=-1)
        entropy = -torch.sum(prob * F.log_softmax(logits_raw, dim=-1), dim=-1)

        num_samples = logits_raw.shape[0]
        m_matrix = torch.ones(num_samples).to(self._device)
        present_classes = []

        # Intra-class re-weighting factor (Eq.9): within a class the sample
        # with the lowest entropy receives the largest sampled weight.
        for cls in range(self._known_classes, self._total_classes):
            class_indices = torch.where(target == cls)[0]
            if len(class_indices) == 0:
                continue
            present_classes.append(class_indices)

            by_entropy = torch.argsort(entropy[class_indices], descending=False)
            gaussian_weights = generate_weights(
                len(class_indices), std=self.args["gau_std"], device=self._device
            )
            m_matrix[class_indices[by_entropy]] = torch.sort(
                gaussian_weights, descending=True
            ).values

        # Inter-class re-weighting factor (Eq.7 and Eq.10)
        # NOTE(review): evaluated left to right this is
        # weight * batch / class_count * num_present_classes; confirm against
        # the equations that multiplying (not dividing) by the number of
        # present classes is intended.
        num_present = len(present_classes)
        for class_indices in present_classes:
            m_matrix[class_indices] = (
                m_matrix[class_indices] * num_samples / len(class_indices) * num_present
            )
        return m_matrix

    def _ridge_solution(self, auto_cor, crs_cor):
        """Solve the regularised regression for one pair of accumulators.

        Args:
            auto_cor: Square accumulator matrix.
            crs_cor: Accumulator with the same number of rows as ``auto_cor``.

        Returns:
            ``(R, Delta)`` with ``R = (auto_cor + regularization * I)^-1`` and
            ``Delta = R @ crs_cor``; ``R`` is float32 on ``self._device``.
        """
        # Invert on the CPU in float64 and cast back afterwards.  This avoids
        # np.mat, which NumPy 2.0 no longer provides.
        gram = auto_cor.detach().cpu().double()
        identity = torch.eye(gram.size(0), dtype=gram.dtype)
        inverse = torch.linalg.inv(gram + self.args["regularization"] * identity)
        inverse = inverse.float().to(self._device)
        return inverse, inverse @ crs_cor


