import torch
import random
import os
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.pyplot as plt
import umap
import numpy as np
import os
import einops
# NOTE(review): not referenced by any function in the visible part of this
# file; presumably a CPU/worker count read elsewhere -- TODO confirm or remove.
NUM_CPUS =4

    

def get_decoder_activity(model, dataloader, task_groups):
    """Run ``model`` over ``dataloader`` and gather its outputs on the CPU.

    Args:
        model: Object called as ``model(x, tasks)`` that returns a 3-tuple
            whose first item is a sequence of output tensors and whose second
            item is the decoder-activity tensor (the third item is ignored).
            It must expose a ``device`` attribute.
        dataloader: Iterable yielding ``(x, y, tasks)`` batches; ``x`` and
            ``tasks`` are moved to ``model.device`` before the forward pass.
        task_groups: Unused; kept so existing call sites keep working.

    Returns:
        A 4-tuple ``(predictions, targets, decodes, tasks)`` where
        ``predictions`` is a list holding, per batch, a list of CPU tensors;
        ``targets`` is the list of ``y`` objects exactly as the dataloader
        yielded them; ``decodes`` is one NumPy array built by concatenating
        every batch's decoder activity along axis 0; and ``tasks`` is a list
        with one NumPy array per batch.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    # Controller parameters and the decoder gain are deliberately not
    # collected: nothing in the returned tuple uses them, and fetching them
    # only added device-to-host copies (and a crash when a model flagged as
    # attention had no controller).
    test_predictions = []
    test_targets = []
    test_decodes = []
    test_tasks = []

    for x, y, tasks in dataloader:
        x = x.to(model.device)
        tasks = tasks.to(model.device)
        with torch.no_grad():
            out, decodes, _ = model(x, tasks)

        # Build a fresh list rather than assigning into ``out`` so that models
        # returning a tuple of outputs are handled as well.
        test_predictions.append([o.cpu() for o in out])
        test_targets.append(y)
        test_decodes.append(decodes.detach().cpu().numpy())
        test_tasks.append(tasks.detach().cpu().numpy())

    if not test_decodes:
        raise ValueError("dataloader yielded no batches; nothing to collect")

    return test_predictions, test_targets, np.concatenate(test_decodes), test_tasks


def plot_proj_data(all_results_pre,all_results_att,all_results_comod,seed=0):
    """Draw three side-by-side UMAP scatter plots, one per model.

    Element ``[2]`` of each results tuple is taken as that model's decoder
    activity.  Every model gets its own, independent UMAP fit, so the three
    panels do not share a coordinate system.  The pretraining activity gets a
    singleton middle axis inserted before projection; the attention and
    comodulation activities are used as given (assumed to be
    ``(samples, tasks, features)`` -- TODO confirm against the caller).

    The figure is written to ``plots/seed_<seed>/isolated_proj_data.png``.
    This function neither creates that directory nor closes the figure.
    """
    def embed_2d(activity):
        # Flatten the two leading axes into rows, reduce with a fresh UMAP,
        # then restore the leading axis so rows can be indexed per task.
        reducer = umap.UMAP()
        n_samples = activity.shape[0]
        flat = einops.rearrange(activity, "b t c -> (b t) c")
        reduced = reducer.fit_transform(flat)
        return einops.rearrange(reduced, "(b t) c -> b t c", b=n_samples)

    def rgba_row(index):
        # scatter() wants a 2-D colour array to treat it as one RGBA colour.
        return np.array(palette(index)).reshape(1, 4)

    palette = matplotlib.colormaps.get_cmap("tab20")

    fig, axes = plt.subplots(1, 3)
    pre_ax, att_ax, comod_ax = axes
    fig.set_dpi(500)

    # Projection order (pre, comod, att) is kept as-is: the UMAP fits are
    # unseeded, so reordering them could change the generated embeddings.
    pre_xy = embed_2d(all_results_pre[2][:, None, :])
    comod_xy = embed_2d(all_results_comod[2])
    att_xy = embed_2d(all_results_att[2])

    pre_ax.scatter(pre_xy[:, 0, 0], pre_xy[:, 0, 1], c=rgba_row(0), alpha=0.5)

    # One colour per task index; the task count is read from the
    # comodulation projection and reused for the attention panel.
    for task_idx in range(comod_xy.shape[1]):
        task_label = "task" + str(task_idx)
        att_ax.scatter(
            att_xy[:, task_idx, 0],
            att_xy[:, task_idx, 1],
            c=rgba_row(task_idx),
            label=task_label,
            alpha=0.5,
        )
        comod_ax.scatter(
            comod_xy[:, task_idx, 0],
            comod_xy[:, task_idx, 1],
            c=rgba_row(task_idx),
            label=task_label,
            alpha=0.5,
        )

    att_ax.legend()
    comod_ax.legend()

    pre_ax.set_title("Pretraining")
    att_ax.set_title("Attention")
    comod_ax.set_title("Comodulation")

    fig.savefig(f"plots/seed_{seed}/isolated_proj_data.png")


def plot_on_same_space(pre_results, att_results, comod_results, chosen_task=None, chosen_class=None, seed=0, max_points=500):
    """Scatter all three models' decoder activity in one shared UMAP space.

    Element ``[2]`` of each results tuple is used as the embedding.  The
    pretraining embedding is used as-is, while the attention and comodulation
    embeddings are indexed with ``[:, chosen_task]``.  The three sets of rows
    are stacked and projected with a single UMAP fit so their 2-D coordinates
    are directly comparable.

    Args:
        pre_results: Results tuple for the pretrained model.
        att_results: Results tuple for the attention model.
        comod_results: Results tuple for the comodulation model.
        chosen_task: Integer index of the task to plot.  Required in practice
            despite the ``None`` default.
        chosen_class: Unused; kept for signature parity with
            ``plot_on_same_space_arrows``.
        seed: Only used to build the output path.
        max_points: Upper bound on the number of points drawn per model.

    Raises:
        ValueError: If ``chosen_task`` is ``None``.

    The figure is saved to ``plots/seed_<seed>/dot_plot_task_<chosen_task>.png``;
    the directory must already exist and the figure is left open for the
    caller to close.
    """
    if chosen_task is None:
        # Indexing with None would insert an axis and make the concatenate
        # below fail with a confusing dimension-mismatch error.
        raise ValueError("chosen_task must be an integer task index, got None")

    pre_emb = pre_results[2]
    att_emb = att_results[2][:, chosen_task]
    comod_emb = comod_results[2][:, chosen_task]

    stacked = np.concatenate((pre_emb, att_emb, comod_emb))
    data_proj = umap.UMAP().fit_transform(stacked)

    # Split the joint projection back using each embedding's own row count,
    # so the groups stay aligned even if the models saw different sample counts.
    att_start = pre_emb.shape[0]
    comod_start = att_start + att_emb.shape[0]
    pre_proj = data_proj[:att_start][:max_points]
    att_proj = data_proj[att_start:comod_start][:max_points]
    comod_proj = data_proj[comod_start:][:max_points]

    fig, ax = plt.subplots()
    fig.set_dpi(500)
    ax.scatter(pre_proj[:, 0], pre_proj[:, 1], c="blue", label="pre", alpha=0.5)
    ax.scatter(comod_proj[:, 0], comod_proj[:, 1], c="red", label="comod", alpha=0.5)
    ax.scatter(att_proj[:, 0], att_proj[:, 1], c="green", label="att", alpha=0.5)
    ax.legend()
    fig.savefig(f"plots/seed_{seed}/dot_plot_task_{chosen_task}.png")


def plot_on_same_space_arrows(pre_results, att_results, comod_results, chosen_task=None, chosen_class=None, seed=0, max_points=500):
    """Draw arrows from each pretraining point to its post-training position.

    Element ``[2]`` of each results tuple is used as the embedding; the
    attention and comodulation embeddings are indexed with
    ``[:, chosen_task]``.  All rows are projected with one UMAP fit so that
    arrows between models are meaningful.  Two figures are produced:

    * ``plots/seed_<seed>/arrow_plot_task_<chosen_task>.png`` -- uniformly
      coloured arrows (attention left, comodulation right) with the mean
      per-point displacement length in each title.
    * ``plots/seed_<seed>/arrow_plot_target_task_<chosen_task>.png`` -- the
      same arrows coloured by whether the target at
      ``[chosen_task][:, chosen_class]`` equals 1.

    Args:
        pre_results: Results tuple for the pretrained model; its element
            ``[1]`` (targets) also drives the colouring of the second figure.
        att_results: Results tuple for the attention model.
        comod_results: Results tuple for the comodulation model.
        chosen_task: Integer index of the task to plot.  Required in practice
            despite the ``None`` default.
        chosen_class: Column of the task's targets used for colouring; should
            be an integer index.
        seed: Only used to build the output paths.
        max_points: Upper bound on the number of arrows drawn per panel.

    Raises:
        ValueError: If ``chosen_task`` is ``None``.

    The output directory must already exist.  The first figure is closed once
    saved; the second is left open as the current figure for the caller.
    """
    if chosen_task is None:
        # Indexing with None would insert an axis and make the concatenate
        # below fail with a confusing dimension-mismatch error.
        raise ValueError("chosen_task must be an integer task index, got None")

    def draw_arrows(ax, starts, ends, colors):
        # One arrow per row, pointing from the start row to the end row.
        deltas = ends - starts
        for (x0, y0), (dx, dy), color in zip(starts, deltas, colors):
            ax.arrow(x0, y0, dx, dy, color=color, width=0.001, head_width=0.04)

    def target_colors(results, color_true, color_false):
        # Pick a colour per sample from the chosen task/class target column.
        task_targets = np.concatenate([t[chosen_task] for t in results[1]])[:max_points]
        is_one = task_targets[:, chosen_class] == 1
        return [color_true if flag else color_false for flag in is_one]

    def mean_shift(ends):
        # Average Euclidean length of the per-point displacement.  Without
        # axis=1 the norm would be one Frobenius norm over the whole matrix.
        return np.linalg.norm(ends - pre_proj, axis=1).mean()

    pre_emb = pre_results[2]
    att_emb = att_results[2][:, chosen_task]
    comod_emb = comod_results[2][:, chosen_task]

    stacked = np.concatenate((pre_emb, att_emb, comod_emb))
    data_proj = umap.UMAP().fit_transform(stacked)

    # Split the joint projection back using each embedding's own row count.
    att_start = pre_emb.shape[0]
    comod_start = att_start + att_emb.shape[0]
    pre_proj = data_proj[:att_start][:max_points]
    att_proj = data_proj[att_start:comod_start][:max_points]
    comod_proj = data_proj[comod_start:][:max_points]

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_dpi(500)
    draw_arrows(ax1, pre_proj, att_proj, ["green"] * len(pre_proj))
    draw_arrows(ax2, pre_proj, comod_proj, ["red"] * len(pre_proj))
    ax1.set_title("Changes for Attention, avg rate change: {:.4f}".format(mean_shift(att_proj)))
    ax2.set_title("Changes for Comodulation, avg rate change: {:.4f}".format(mean_shift(comod_proj)))
    fig.savefig("plots/seed_{}/arrow_plot_task_{}.png".format(seed, chosen_task))
    # A caller's bare plt.close() only closes the newest figure, so release
    # this one here to avoid leaking it.
    plt.close(fig)

    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_dpi(500)
    # Targets come from pre_results for both panels, and samples whose target
    # equals 1 are drawn red (others green), matching the existing plots.
    arrow_colors = target_colors(pre_results, "red", "green")
    draw_arrows(ax1, pre_proj, att_proj, arrow_colors)
    ax1.set_title("Attention")
    draw_arrows(ax2, pre_proj, comod_proj, arrow_colors)
    ax2.set_title("Comodulation")
    fig.savefig("plots/seed_{}/arrow_plot_target_task_{}.png".format(seed, chosen_task))
    
def plot_embedding_results(pretrained_model, attention_model, comod_model, dataloader, task_groups, seed):
    """Collect decoder activity for three models and write all embedding plots.

    Runs ``get_decoder_activity`` on each model, makes sure
    ``plots/seed_<seed>/`` exists, then produces the isolated projection plot
    plus the shared-space dot and arrow plots for (task 0, class 5) and
    (task 3, class 1).

    Plotting is best-effort: an exception inside either plotting group is
    printed and does not propagate.  Figures opened while this function runs
    are closed before it returns; figures that were already open are left
    untouched.

    Args:
        pretrained_model: Model evaluated as the pretraining baseline.
        attention_model: Model evaluated as the attention variant.
        comod_model: Model evaluated as the comodulation variant.
        dataloader: Passed unchanged to ``get_decoder_activity`` for each model.
        task_groups: Passed unchanged to ``get_decoder_activity``.
        seed: Used to name the output directory.
    """
    pretrain_results = get_decoder_activity(pretrained_model, dataloader, task_groups)
    att_results = get_decoder_activity(attention_model, dataloader, task_groups)
    comod_results = get_decoder_activity(comod_model, dataloader, task_groups)

    # makedirs creates the parent "plots" directory when missing and tolerates
    # reruns with the same seed, both of which made os.mkdir raise.
    os.makedirs("plots/seed_{}/".format(seed), exist_ok=True)

    preexisting_figures = set(plt.get_fignums())

    def close_new_figures():
        # Close every figure opened since this function started; a bare
        # plt.close() would only reach the most recently created one.
        for number in plt.get_fignums():
            if number not in preexisting_figures:
                plt.close(number)

    try:
        # The isolated projection does not depend on the chosen task, so it is
        # generated once instead of being recomputed for each task group.
        plot_proj_data(pretrain_results, att_results, comod_results, seed=seed)
        close_new_figures()
        plot_on_same_space(pretrain_results, att_results, comod_results, seed=seed, chosen_task=0, chosen_class=5)
        close_new_figures()
        plot_on_same_space_arrows(pretrain_results, att_results, comod_results, chosen_task=0, chosen_class=5, seed=seed)
    except Exception as e:
        print("Error plotting data :",e)
    finally:
        close_new_figures()

    try:
        plot_on_same_space(pretrain_results, att_results, comod_results, seed=seed, chosen_task=3, chosen_class=1)
        close_new_figures()
        plot_on_same_space_arrows(pretrain_results, att_results, comod_results, chosen_task=3, chosen_class=1, seed=seed)
    except Exception as e:
        print(f"Error plotting data :{e}")
    finally:
        close_new_figures()

    

    
