import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch 
def get_decoder_activity(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect its outputs.

    ``model`` must expose ``device``, ``controller`` (may be ``None``) and
    ``network.decoder.gain``. Calling ``model(x, tasks)`` must return a
    3-tuple whose first two items are tensors (outputs and decodes).
    ``dataloader`` must yield ``(x, y, tasks)`` triples of tensors.

    NOTE(review): the model is not switched to eval mode here, so
    dropout/batch-norm behaviour depends on the caller.

    Args:
        model: The network to evaluate. Only ``x`` and ``tasks`` are moved
            to ``model.device``; ``y`` stays where the dataloader put it.
        dataloader: Iterable of ``(x, y, tasks)`` batches.

    Returns:
        Tuple ``(predictions, targets, decodes, tasks)``. Each item is the
        per-batch tensors concatenated along dim 0, on the CPU.
        ``torch.cat`` raises if ``dataloader`` yields no batches.
    """
    def get_out_decode_etc(model, x, tasks):
        # Controller params and the decoder gain are unused by the loop below.
        # They are still computed so this helper keeps its original behaviour,
        # including any side effects of ``get_controller_params``.
        controller_params = None
        if model.controller is not None:
            controller_params = model.controller.get_controller_params(x, tasks)
        out, decodes, _ = model(x, tasks)
        return out, controller_params, decodes, model.network.decoder.gain

    test_predictions = []
    test_targets = []
    test_decodes = []
    test_tasks = []
    for x, y, tasks in dataloader:
        x = x.to(model.device)
        tasks = tasks.to(model.device)
        with torch.no_grad():
            out, _, decodes, _ = get_out_decode_etc(model, x, tasks)
        # Do one bulk device->host copy per tensor. The earlier per-row loop
        # copied each row to the host and back into the device tensor, which
        # was pure overhead before this copy.
        test_predictions.append(out.detach().cpu())
        test_targets.append(y.detach().cpu())
        test_decodes.append(decodes.detach().cpu())
        test_tasks.append(tasks.detach().cpu())

    test_predictions = torch.cat(test_predictions)
    test_targets = torch.cat(test_targets)
    test_decodes = torch.cat(test_decodes)
    test_tasks = torch.cat(test_tasks)
    return test_predictions, test_targets, test_decodes, test_tasks


def print_lda_all(pre_results, att_results, comod_results, plot_labels=False, plot_name=None, n_components=80):
    """Fit an LDA on three models' embeddings and plot cumulative explained variance.

    Each ``*_results`` argument is indexed as ``results[1]`` (targets),
    ``results[2]`` (embeddings, 2-D) and ``results[-1]`` (task labels). All
    three must carry identical task labels, which is checked with an
    ``assert``. The targets of ``pre_results`` serve as class labels for
    every fit.

    Args:
        pre_results: Results for the pretrained model.
        att_results: Results for the attention model.
        comod_results: Results for the comodulation model.
        plot_labels: Unused; kept for backward compatibility.
        plot_name: If not ``None``, the figure is saved there as a PDF.
        n_components: Upper bound on the number of LDA components. It is
            clamped to ``min(n_classes - 1, n_features)``, the largest value
            scikit-learn accepts. One value is shared by all three fits so the
            returned array is rectangular.

    Returns:
        ``np.ndarray`` with one row of ``explained_variance_ratio_`` per
        model, in the order pretrained, attention, comodulation.
    """
    tasks = pre_results[-1]
    targets = pre_results[1]

    # All three models must have been evaluated on the same samples, in the same order.
    assert (tasks == att_results[-1]).all() and (tasks == comod_results[-1]).all()

    embeddings = {
        "Pretrained": pre_results[2],
        "Attention": att_results[2],
        "Comodulation": comod_results[2],
    }

    # LDA yields at most n_classes - 1 discriminant axes; a larger request raises.
    n_classes = len(np.unique(np.asarray(targets)))
    n_features = min(emb.shape[1] for emb in embeddings.values())
    k = max(1, min(n_components, n_classes - 1, n_features))

    ratios = [
        LinearDiscriminantAnalysis(n_components=k).fit(emb, targets).explained_variance_ratio_
        for emb in embeddings.values()
    ]

    fig, ax = plt.subplots(1, 1)
    for label, ratio in zip(embeddings, ratios):
        ax.plot(np.cumsum(ratio), label=label)
    ax.legend()
    if plot_name is not None:
        fig.savefig(plot_name, format="pdf", bbox_inches="tight")
    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)
    return np.array(ratios)

def print_pca_all(pre_results, att_results, comod_results, plot_name=None, n_components=80):
    """Fit a PCA on three models' embeddings and plot cumulative explained variance.

    Each ``*_results`` argument is indexed as ``results[2]`` (embeddings,
    2-D) and ``results[-1]`` (task labels). All three must carry identical
    task labels, which is checked with an ``assert``.

    Args:
        pre_results: Results for the pretrained model.
        att_results: Results for the attention model.
        comod_results: Results for the comodulation model.
        plot_name: If not ``None``, the figure is saved there as a PDF.
        n_components: Upper bound on the number of principal components. It
            is clamped to ``min(n_samples, n_features)``, the largest value
            scikit-learn's PCA accepts. One value is shared by all three fits
            so the returned array is rectangular.

    Returns:
        ``np.ndarray`` with one row of ``explained_variance_ratio_`` per
        model, in the order pretrained, attention, comodulation.
    """
    tasks = pre_results[-1]

    # All three models must have been evaluated on the same samples, in the same order.
    assert (tasks == att_results[-1]).all() and (tasks == comod_results[-1]).all()

    embeddings = {
        "Pretrained": pre_results[2],
        "Attention": att_results[2],
        "Comodulation": comod_results[2],
    }

    n_samples = min(emb.shape[0] for emb in embeddings.values())
    n_features = min(emb.shape[1] for emb in embeddings.values())
    k = max(1, min(n_components, n_samples, n_features))

    # PCA is unsupervised, so the targets are not passed to fit().
    ratios = [
        PCA(n_components=k).fit(emb).explained_variance_ratio_
        for emb in embeddings.values()
    ]

    fig, ax = plt.subplots(1, 1)
    for label, ratio in zip(embeddings, ratios):
        ax.plot(np.cumsum(ratio), label=label)
    ax.legend()
    if plot_name is not None:
        fig.savefig(plot_name, format="pdf", bbox_inches="tight")
    # Close this specific figure, not whatever pyplot considers "current".
    plt.close(fig)
    return np.array(ratios)

def plot_pca_lda(pre_model, att_model, com_model, dataloader, folder=None, seed=0):
    """Collect decoder activity from three models and compare them via LDA and PCA.

    ``dataloader`` is iterated once per model. The LDA/PCA helpers assert that
    the task labels line up across models, so the loader presumably has to
    yield samples in the same order on every pass (e.g. no shuffling).

    Args:
        pre_model: Pretrained model.
        att_model: Attention model.
        com_model: Comodulation model.
        dataloader: Iterable of ``(x, y, tasks)`` batches.
        folder: Directory to save ``lda_cumsum_<seed>.pdf`` and
            ``pca_cumsum_<seed>.pdf`` into. If ``None``, nothing is saved.
        seed: Only used to build the output file names.

    Returns:
        Dict with keys ``"lda_expl"`` and ``"pca_expl"``, holding the
        explained-variance-ratio arrays from the LDA and PCA comparisons.
    """
    pre_res = get_decoder_activity(pre_model, dataloader)
    att_res = get_decoder_activity(att_model, dataloader)
    com_res = get_decoder_activity(com_model, dataloader)

    # Without a folder there is nowhere to save. Formatting None into the path
    # would give "None/lda_cumsum_0.pdf", and savefig would fail.
    lda_name = None if folder is None else f"{folder}/lda_cumsum_{seed}.pdf"
    pca_name = None if folder is None else f"{folder}/pca_cumsum_{seed}.pdf"

    ret_dict = {}
    ret_dict["lda_expl"] = print_lda_all(pre_res, att_res, com_res, plot_name=lda_name)
    ret_dict["pca_expl"] = print_pca_all(pre_res, att_res, com_res, plot_name=pca_name)
    return ret_dict
    