from lib.config import cfg
from torch import optim
from utils.vision import *
from typing import Iterable
from pathlib import Path
from data.dataset import *
from timm.utils import accuracy
from torch.distributions import MultivariateNormal
from lib.random_projection import setup_RP, replace_fc, optimise_ridge_parameter
from lib.utils import *

import torch
import math
import numpy as np
import os
import copy


# Module-level state filled in by train_and_evaluate, keyed by task id.
cls_mean = {}        # task_id -> per-cluster feature means from the task's mixture model
cls_cov = {}         # task_id -> per-cluster feature covariances
cls_labels = {}      # task_id -> cluster label of every training feature
gaussian_dist = {}   # task_id -> list of MultivariateNormal built from the stats above
# Deep copies of the random-projection network, appended once per trained task (in training order).
task_lora_universe_head = []

def train_one_epoch(model: torch.nn.Module, 
                    criterion, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    set_training_mode=True, task_id=-1, class_mask=None,  
                    alpha=0.5, reg_weight=0.006, ortho_weight=0):
    """Run one training epoch of ``model`` on the current task's data.

    Args:
        model: Network called as ``model(x, task_id=..., alpha=...)``. The returned dict must
            contain ``"logits"`` and, when ``task_id > 0``, ``"reg_W_Q"``, ``"reg_W_V"``,
            ``"ortho_loss_Q"`` and ``"ortho_loss_V"``.
        criterion: Base classification loss applied to the (masked) logits.
        data_loader: Iterable of ``(input, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to.
        epoch: Zero-based epoch index, used only in the log header.
        max_norm: Gradient-norm clipping threshold. ``None`` or a value ``<= 0`` disables
            clipping (a threshold of 0 would otherwise zero out every gradient).
        set_training_mode: Passed to ``model.train``.
        task_id: Current task index. The regularisation terms are added only when ``task_id > 0``.
        class_mask: Optional per-task sequences of class indices. Logits of every class not in
            ``class_mask[task_id]`` are set to ``-inf``.
        alpha: Forwarded to ``model``.
        reg_weight: Weight of ``reg_W_Q + reg_W_V``.
        ortho_weight: Weight of ``ortho_loss_Q + ortho_loss_V``. ``None`` is treated as 0.

    Returns:
        Dict mapping each logged meter name to its global average over the epoch.
    """
    import sys

    model.train(set_training_mode)
    if ortho_weight is None:
        # train_and_evaluate defaults ortho_weight to None; that means "no orthogonality term".
        ortho_weight = 0

    total_epochs = cfg.dtask.epochs
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('Lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('Loss', SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = f'Train: Epoch[{epoch + 1:{len(str(total_epochs))}}/{total_epochs}]'

    # Mask out classes of non-current tasks. The excluded index set depends only on task_id,
    # so it is built once per epoch instead of once per batch.
    not_mask = None
    if class_mask is not None:
        excluded = np.setdiff1d(np.arange(cfg.dtask.nb_classes), class_mask[task_id])
        not_mask = torch.tensor(excluded, dtype=torch.int64).to(device)

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        output = model(samples, task_id=task_id, alpha=alpha)
        logits = output["logits"]
        if not_mask is not None:
            logits = logits.index_fill(dim=1, index=not_mask, value=float('-inf'))

        loss = criterion(logits, targets)  # base criterion (CrossEntropyLoss)
        if task_id > 0:
            loss = loss + (reg_weight * (output["reg_W_Q"] + output["reg_W_V"])
                           + ortho_weight * (output["ortho_loss_Q"] + output["ortho_loss_V"]))

        acc1, acc5 = accuracy(logits, targets, topk=(1, 5))

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        if max_norm is not None and max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        # synchronize() raises on builds without CUDA, so only call it when CUDA exists.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        batch_size = samples.shape[0]
        metric_logger.update(Loss=loss_value)
        metric_logger.update(Lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['Acc@1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['Acc@5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def evaluate_till_now_with_pred_tid(pred_task_id,  task_lora_universe_head, data_loader, device, task_id=-1, 
                                    acc_matrix=None, alpha=0.3, mapping_classes=None):
    """Evaluate tasks ``0..task_id``, routing each batch to the head chosen by task-id prediction.

    Args:
        pred_task_id: Indexed by task. ``pred_task_id[tid][b]`` is the predicted task id for
            batch ``b`` of task ``tid``'s validation loader.
        task_lora_universe_head: Per-task networks, indexed by predicted task id and called as
            ``net(input, alpha=alpha)``. Each returns a dict with ``'logits'``.
        data_loader: Indexable by task id. ``data_loader[tid]["val"]`` yields ``(input, target)``.
            NOTE(review): batches must come in the same order as when ``pred_task_id`` was computed.
        device: Device the loss criterion is moved to. Inputs are moved to ``cfg.device``.
        task_id: Index of the latest task. Tasks ``0..task_id`` are evaluated.
        acc_matrix: Optional 2-D array. ``acc_matrix[tid, task_id]`` receives task ``tid``'s Acc@1,
            and forgetting/backward transfer are reported from it when ``task_id > 0``.
        alpha: Forwarded to each head.
        mapping_classes: Optional lookup that remaps raw labels before the loss/accuracy.

    Returns:
        Dict of global-average ``Loss``/``Acc@1``/``Acc@5`` for the LAST evaluated task only.
    """
    stat_matrix = np.zeros((3, cfg.continual.n_tasks))
    criterion = torch.nn.CrossEntropyLoss().to(device)

    for tid in range(task_id + 1):
        metric_logger = MetricLogger(delimiter="  ")
        header = 'Test: [Task {}]'.format(tid + 1)
        batch_preds = pred_task_id[tid]

        with torch.no_grad():
            batches = metric_logger.log_every(data_loader[tid]["val"], 10, header)
            for b_id, (input, target) in enumerate(batches):
                input = input.to(cfg.device, non_blocking=True)
                target = target.to(cfg.device, non_blocking=True)
                if mapping_classes is not None:
                    # One host transfer for the whole batch instead of one per label.
                    remapped = [mapping_classes[int(label)] for label in target.tolist()]
                    target = torch.tensor(remapped).to(cfg.device, non_blocking=True)

                pred_tid = int(batch_preds[b_id].item())
                output = task_lora_universe_head[pred_tid](input, alpha=alpha)
                logits = output['logits']
                loss = criterion(logits, target)
                acc1, acc5 = accuracy(logits, target, topk=(1, 5))

                batch_size = input.shape[0]
                metric_logger.meters['Loss'].update(loss.item())
                metric_logger.meters['Acc@1'].update(acc1.item(), n=batch_size)
                metric_logger.meters['Acc@5'].update(acc5.item(), n=batch_size)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
            .format(top1=metric_logger.meters['Acc@1'], top5=metric_logger.meters['Acc@5'], losses=metric_logger.meters['Loss']))

        test_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}

        stat_matrix[0, tid] = test_stats['Acc@1']
        stat_matrix[1, tid] = test_stats['Acc@5']
        stat_matrix[2, tid] = test_stats['Loss']

        if acc_matrix is not None:
            acc_matrix[tid, task_id] = test_stats['Acc@1']

    avg_stat = np.divide(np.sum(stat_matrix, axis=1), task_id + 1)
    result_str = "[Average accuracy till task{}]\tAcc@1: {:.4f}\tAcc@5: {:.4f}\tLoss: {:.4f}".format(task_id + 1, avg_stat[0], avg_stat[1], avg_stat[2])

    # Forgetting/backward transfer average over previous tasks, which do not exist at task 0
    # (np.mean of the empty slice would warn and yield NaN).
    if acc_matrix is not None and task_id > 0:
        diagonal = np.diag(acc_matrix)
        forgetting = np.mean((np.max(acc_matrix, axis=1) - acc_matrix[:, task_id])[:task_id])
        backward = np.mean((acc_matrix[:, task_id] - diagonal)[:task_id])
        result_str += "\tForgetting: {:.4f}\tBackward: {:.4f}".format(forgetting, backward)
    print(result_str)

    return test_stats


def evaluate_till_now(data_loader, device, task_id=-1, acc_matrix=None, alpha=0.3, 
                      models=None, gaussian_dist=None, mapping_classes=None):
    """Predict a task id for every validation batch, then evaluate with the predicted heads.

    For each validation batch of task ``tid``, every candidate task model ``mid`` embeds the
    batch (``'pre_logits'``). The batch is scored by the best batch-mean log-likelihood under
    ``gaussian_dist[mid]``, and the highest-scoring ``mid`` is taken as the predicted task.

    Args:
        data_loader: Indexable by task id. ``data_loader[tid]["val"]`` yields ``(inputs, targets)``.
        device: Forwarded to ``evaluate_till_now_with_pred_tid``.
        task_id: Index of the latest task. Tasks/models ``0..task_id`` are considered.
        acc_matrix: Forwarded to ``evaluate_till_now_with_pred_tid``.
        alpha: Forwarded to every model.
        models: Per-task networks returning a dict with ``'pre_logits'`` (and ``'logits'``).
        gaussian_dist: Task id -> list of distributions exposing ``log_prob``. A task whose list
            is empty scores ``-inf`` and is never preferred over a task with distributions.
        mapping_classes: Forwarded to ``evaluate_till_now_with_pred_tid``.

    Returns:
        The stats dict returned by ``evaluate_till_now_with_pred_tid``.
    """

    def _best_log_likelihood(feature, dists):
        # Highest batch-mean log-likelihood over dists; -inf if the task kept no Gaussian.
        if not dists:
            return torch.tensor(float('-inf'), dtype=feature.dtype, device=feature.device)
        return torch.stack([torch.mean(dist.log_prob(feature)) for dist in dists]).max()

    with torch.no_grad():
        pred_task_id = []
        for tid in range(task_id + 1):
            per_model_scores = []
            for mid in range(task_id + 1):
                batch_scores = []
                for inputs, _ in data_loader[tid]["val"]:
                    inputs = inputs.to(cfg.device, non_blocking=True)
                    feature = models[mid](inputs, alpha=alpha)['pre_logits']
                    batch_scores.append(_best_log_likelihood(feature, gaussian_dist[mid]))
                per_model_scores.append(torch.stack(batch_scores))

            # Rows: candidate task models; columns: batches of task tid.
            scores = torch.stack(per_model_scores)
            pred_task_id.append(torch.argmax(scores, dim=0))
            print("Task: ", tid, pred_task_id[tid])

    # Compute accuracy of matching
    pred_tid_stats = [(pred_task_id[i] == i).sum().item() for i in range(task_id + 1)]
    true_tid_stats = [len(pred_task_id[i]) for i in range(task_id + 1)]
    acc_tid_stats = [pred / true for (pred, true) in zip(pred_tid_stats, true_tid_stats)]
    print("Accuracy of matching: ", sum(pred_tid_stats) / sum(true_tid_stats))
    print("Accuracy of matching according to each task: ", acc_tid_stats)

    test_stats = evaluate_till_now_with_pred_tid(pred_task_id, models, data_loader, device, task_id=task_id,
                                                 acc_matrix=acc_matrix, alpha=alpha, mapping_classes=mapping_classes)

    return test_stats

def _save_lora_checkpoint(model, task_id):
    # Persist the LoRA pools and, after the first task, the aggregation weights and new units.
    checkpoint_path = os.path.join(cfg.dtask.output_dir, 'lora_task_{}'.format(task_id), 'checkpoint/')
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
    for name in ('Q_lora_pool_B', 'Q_lora_pool_A', 'V_lora_pool_B', 'V_lora_pool_A'):
        torch.save(getattr(model, name), os.path.join(checkpoint_path, name + '.pth'))
    if task_id > 0:
        for name in ('W_Q', 'W_V', 'new_Q_lora_B', 'new_Q_lora_A', 'new_V_lora_B', 'new_V_lora_A'):
            torch.save(getattr(model, name), os.path.join(checkpoint_path, name + '.pth'))


def _fit_task_gaussians(means, covs):
    # Build one MultivariateNormal per cluster, skipping clusters whose covariance is all zero.
    dists = []
    for mean, var in zip(means, covs):
        if var.mean() == 0:
            print("var.mean is empty")
            continue
        loc = torch.tensor(mean).to(cfg.device)
        # Diagonal jitter regularises the covariance so the factorisation inside
        # MultivariateNormal is less likely to fail on near-singular estimates.
        cov = torch.tensor(var).to(cfg.device) + 1e-4 * torch.eye(mean.shape[0]).to(cfg.device)
        dists.append(MultivariateNormal(loc, cov))
    return dists


def train_and_evaluate(tasks, model, criterion, data_loader, device, class_mask, acc_matrix, 
                       reg_weight, added_units=4, alpha=0.3, network=None, M=0, Q=None, G=None,
                       mapping_classes=None, ortho_weight=None):
    """Train the LoRA model task by task, fit task Gaussians and the ridge head, and evaluate.

    For every task id in ``tasks``:
      1. For tasks after the first, grow the LoRA pool (``model.add_new_units``) and decay
         ``reg_weight`` by 0.8.
      2. Train for ``cfg.dtask.epochs`` epochs with a fresh Adam optimizer.
      3. Save the LoRA checkpoint when ``cfg.dtask.output_dir`` is set (main process only).
      4. Cluster training features (``_compute_mean``) and store the stats/Gaussians in the
         module-level ``cls_mean``/``cls_cov``/``cls_labels``/``gaussian_dist``.
      5. Copy the model into ``network``, refit its random-projection ridge classifier and append
         a deep copy to ``task_lora_universe_head``. All stored heads then share the new weights.
      6. Evaluate all tasks seen so far, then merge the new units (tasks after the first).

    Args:
        tasks: Task ids to train, in order.
        model: LoRA model (``add_new_units``, ``merge_units`` and the LoRA pool attributes).
        criterion: Classification loss forwarded to ``train_one_epoch``.
        data_loader: Indexable by task id, with ``"train"`` and ``"val"`` loaders.
        device: Device for the model and training batches.
        class_mask: Per-task class-index sequences. Their total size up to a task sets the head size.
        acc_matrix: Accuracy matrix filled during evaluation.
        reg_weight: Initial weight of the LoRA regulariser.
        added_units: Number of units added to the pool per new task.
        alpha: Forwarded to the model everywhere.
        network: Random-projection network wrapper (required despite the ``None`` default).
        M: Forwarded to ``setup_RP``.
        Q, G: Accumulated statistics passed to and returned by ``replace_fc``.
        mapping_classes: Optional label remapping forwarded to ``replace_fc`` and evaluation.
        ortho_weight: Orthogonality-loss weight. ``None`` is treated as 0.
    """
    if ortho_weight is None:
        # train_one_epoch multiplies this into the loss; None would raise for task_id > 0.
        ortho_weight = 0
    W_rand = None      # random projection, created once for the first task processed
    optimizer = None
    epoch = None

    for task_id in tasks:
        if task_id > 0:
            model.add_new_units(added_units)
            model.to(device)
            reg_weight = 0.8 * reg_weight  # decrease reg_weight because the pool grows after each task

        # Create a new optimizer for each task to clear optimizer state.
        optimizer = optim.Adam(model.parameters(), lr=cfg.dtask.lr)

        for epoch in range(cfg.dtask.epochs):
            train_one_epoch(model=model, criterion=criterion,
                            data_loader=data_loader[task_id]['train'], optimizer=optimizer,
                            device=device, epoch=epoch, max_norm=1.0,
                            set_training_mode=True, task_id=task_id, class_mask=class_mask,
                            alpha=alpha, reg_weight=reg_weight, ortho_weight=ortho_weight)

        if cfg.dtask.output_dir and is_main_process():
            _save_lora_checkpoint(model, task_id)

        log('BEGIN TO LEARN GAUSSIAN DISTRIBUTIONS OF TASK {}'.format(task_id))
        t_means, t_cov, t_label = _compute_mean(model, data_loader[task_id]['train'], device, task_id=task_id,
                                                class_mask=None, alpha=alpha, log_folder_dir=cfg.dtask.output_dir)
        cls_mean[task_id] = t_means
        cls_cov[task_id] = t_cov
        cls_labels[task_id] = t_label
        gaussian_dist[task_id] = _fit_task_gaussians(t_means, t_cov)

        log('BEGIN TO LEARN LDA CLASSIFIER OF TASK {}'.format(task_id))
        # Update RP model
        network.update_backbone(copy.deepcopy(model))
        del network.fc
        network.fc = None
        new_heads = sum(len(class_mask[tid]) for tid in range(task_id + 1))
        network.update_fc(new_heads)
        # Freeze the RP backbone.
        for n, p in network.named_parameters():
            if 'convnet' in n:
                p.requires_grad = False

        n_parameters = sum(p.numel() for p in network.parameters() if p.requires_grad)
        log('number of params: %d' % n_parameters)
        if W_rand is None:
            # Previously only set when task_id == 0, which left W_rand unbound if tasks started later.
            W_rand = setup_RP(network, M)

        Y, Features_h, Q, G = replace_fc(network, data_loader[task_id]["train"], W_rand, Q, G, mapping_classes)
        ridge = optimise_ridge_parameter(Features_h, Y)
        # Build the identity on G's device so the solve also works when G is not on the CPU.
        identity = torch.eye(G.size(dim=0), device=G.device)
        Wo = torch.linalg.solve(G + ridge * identity, Q).T
        network.fc.weight.data = Wo[0:network.fc.weight.shape[0], :].to(cfg.device)
        # Store task-id network
        task_lora_universe_head.append(copy.deepcopy(network))

        log('BEGIN TO EVALUATE ALL TASKS UNTIL NOW - TASK {}'.format(task_id))
        head_weight = copy.deepcopy(network.fc.weight.data)  # latest classification weight
        for tid in range(task_id + 1):
            task_lora_universe_head[tid].fc.weight = torch.nn.Parameter(head_weight)

        evaluate_till_now(data_loader, device, task_id, acc_matrix=acc_matrix,
                          alpha=alpha, models=task_lora_universe_head,
                          gaussian_dist=gaussian_dist, mapping_classes=mapping_classes)

        if task_id > 0:
            model.merge_units()

    output_dir = cfg.dtask.output_dir
    if not output_dir:
        # Without an output directory the old code wrote to "/..." (or raised on None).
        return

    if is_main_process():
        torch.save(cls_mean, os.path.join(output_dir, "cls_mean.pth"))
        torch.save(cls_cov, os.path.join(output_dir, "cls_cov.pth"))
        torch.save(cls_labels, os.path.join(output_dir, "cls_labels.pth"))

    if optimizer is not None:  # at least one task was trained
        state_dict = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        save_on_master(state_dict, os.path.join(output_dir, "model.pth"))
    
@torch.no_grad()
def _compute_mean(model: torch.nn.Module, data_loader: Iterable, device: torch.device, task_id, class_mask=None, 
                  alpha=0.3, log_folder_dir=''):
    """Cluster a task's training features with a Bayesian GMM and return per-cluster statistics.

    Also saves a t-SNE scatter plot of the features, coloured by cluster, to ``log_folder_dir``.

    Args:
        model: Called as ``model(inputs, alpha=alpha)``. Its ``'pre_logits'`` output is clustered.
            The model is left in eval mode.
        data_loader: Iterable of ``(inputs, targets)`` batches. Targets are ignored.
        device: Device inputs are moved to.
        task_id: Selects ``cfg.dtask.n_components[task_id]`` and names the plot file.
        class_mask: Unused; kept for interface compatibility.
        alpha: Forwarded to ``model``.
        log_folder_dir: Directory for the t-SNE figure. Empty/None means the current directory.

    Returns:
        ``(means, covariances, labels)``: the fitted ``BayesianGaussianMixture``'s ``means_`` and
        ``covariances_`` (full matrices), and the predicted component of every training feature.
    """
    model.eval()
    chunks = []
    for inputs, _ in data_loader:
        inputs = inputs.to(device, non_blocking=True)
        chunks.append(model(inputs, alpha=alpha)['pre_logits'].detach().cpu())
    features = torch.cat(chunks, dim=0).numpy()

    from sklearn.mixture import BayesianGaussianMixture

    bgm = BayesianGaussianMixture(
        n_components=cfg.dtask.n_components[task_id],  # maximum number of components
        weight_concentration_prior=cfg.dtask.weight_concentration_prior,
        covariance_type='full',
    )
    bgm.fit(features)
    cluster_labels = bgm.predict(features)
    cluster_means = bgm.means_
    cluster_vars = bgm.covariances_

    _save_tsne_plot(features, cluster_labels, task_id, log_folder_dir)

    return cluster_means, cluster_vars, cluster_labels


def _save_tsne_plot(features, cluster_labels, task_id, log_folder_dir):
    # Embed features in 2-D with t-SNE and save a scatter plot coloured by cluster label.
    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt

    tsne_kwargs = dict(n_components=2, random_state=42, perplexity=30)
    try:
        # scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` (``n_iter`` removed in 1.7).
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn only accepts ``n_iter``.
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    embedded = tsne.fit_transform(features)

    fig = plt.figure(figsize=(5, 3))
    scatter = plt.scatter(embedded[:, 0], embedded[:, 1], c=cluster_labels, cmap='tab10', s=1)
    plt.colorbar(scatter, label="Cluster Label")
    plt.title("t-SNE Visualization of Clusters")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    # os.path.join avoids writing to the filesystem root when log_folder_dir is empty.
    plt.savefig(os.path.join(log_folder_dir or '', f"t-SNE kmeans {task_id}.png"))
    plt.show()
    # Release the figure; otherwise one figure per task accumulates in pyplot's registry.
    plt.close(fig)
    
    


    