from audioop import add
from copy import copy, deepcopy
from dataclasses import dataclass
import math
import os
import sys
from typing import List
from matplotlib import pyplot as plt
import numpy as np
from sklearn import metrics
import torch
import torch.nn.functional as F

from eval_scripts.performance import f1_score
from tqdm import tqdm
from custom_model.vmoe_module import MultiModalityConfig
from thop import profile, clever_format
from custom_model.vmoe_module import FMoELinear
import torch.nn as nn
from sklearn.metrics import accuracy_score

from custom_model.common_models import Reshape

# Module-level counter, initialised to zero. Nothing in the visible part of this
# file reads or updates it -- presumably a tally for custom-op FLOPs counting
# (``thop`` is imported above); TODO confirm or remove.
cus_ops = 0
def _build_inputs(batch, names, unsqueeze, transposed, device):
    """Turn one task batch into the ``{modality_name: tensor}`` dict fed to the model.

    ``batch[k]`` is paired with ``names[k]``; every tensor is moved to
    ``device`` and cast to float.  A trailing axis is appended when
    ``unsqueeze`` is set, otherwise axes 1 and 2 are swapped when
    ``transposed`` is set.
    """
    inputs = {}
    for idx, name in enumerate(names):
        tensor = batch[idx].to(device).float()
        if unsqueeze:
            tensor = tensor.unsqueeze(-1)
        elif transposed:
            tensor = tensor.transpose(1, 2)
        inputs[name] = tensor
    return inputs


def _task_loss(criterion, output, target, classification, device):
    """Apply ``criterion``; classification targets are cast to a flat long vector first."""
    if classification:
        return criterion(output, target.to(device).long().reshape(-1))
    return criterion(output, target.to(device))


def _count_correct(logits, target):
    """Return ``(number of argmax hits, number of predictions)`` for one batch."""
    preds = torch.argmax(logits, dim=1)
    labels = target.to(preds.device).long().reshape(-1)
    return int((preds == labels).sum().item()), len(preds)


def train(model:nn.Module,epochs,trains,valid
          ,test,modalities,savedir,lr=0.001
          ,weight_decay=0.0, optimizer=torch.optim.Adam, 
          criterions=[torch.nn.CrossEntropyLoss(), torch.nn.CrossEntropyLoss(), torch.nn.MSELoss()],
          valid_criterions = None,
          test_criterions = None,
          unsqueezing=[True,True], device="cuda:0",
          train_weights=[1.0,1.0],
          eval_weights = None,
          is_affect=[False,False],
          is_hci = False,
          transpose=[False,False],calc_flops = False, encoder = None,
          is_classification = [True, True, False],
          flips=-1, classesnum=[2,10,2,2],start_from=0,
          getattentionmap=False, is_train = False, args = None, mconfig : MultiModalityConfig= None,
          grad_clip = False,
          schedular = None, goal_points = None, tune_gate_weight = False, moe_gate_weight = [1.,1.,1.]):
    """Jointly train (optionally), validate and test a multi-task, multi-modality model.

    When ``is_train`` is set, every epoch interleaves the batches of all tasks
    into shared optimisation steps, validates each task, saves the state dict
    to ``savedir`` whenever the combined validation score improves and adapts
    the per-modality top-k passed to the model.  Afterwards (or immediately,
    when ``is_train`` is false) the checkpoint at ``savedir`` is loaded and
    every task in ``test`` is evaluated; results are printed.

    Args:
        model: Network called as ``model(inputs, task_id=..., modality_topk=...)``
            returning ``(out, sm_out)`` where ``sm_out`` maps modality names to
            extra outputs.  Must expose ``to_logitslist``, ``gate_loss`` and
            ``gate_logits``; its heads must expose ``set_modality_weight``.
        epochs: Number of training epochs.
        trains, valid, test: One iterable of batches per task (index = task id).
            The last element of a batch is the target; the leading elements
            line up with ``modalities[task]``.
        modalities: Per-task list of modality names.
        savedir: Path used with ``torch.save`` / ``torch.load`` for the best
            state dict.
        lr, weight_decay, optimizer: Optimiser factory and its settings.
        criterions, valid_criterions, test_criterions: Per-task losses;
            validation defaults to training, test defaults to validation.
        unsqueezing, transpose: Per-task flags selecting the input reshaping
            (append a trailing axis / swap axes 1 and 2).
        device: Device the inputs and targets are moved to.
        train_weights: Per-task weight of the main loss.
        eval_weights: Optional per-task weights of the validation score.
        is_affect: Per-task flag for batches laid out as
            ``[modalities_list, ..., ..., label]``; the label is appended to
            the modality list (binarised with ``> 0`` for cross-entropy).
        is_hci: Enables gradient clipping and an extra sklearn accuracy print.
        encoder: Optional per-task callables applied to the input dict.
        is_classification: Per-task flag; accuracy is the score when set,
            negative loss otherwise.
        flips: Task id whose targets are shifted by one class (mod
            ``classesnum[task]``); ``-1`` disables it.
        classesnum: Per-task class counts, used only by ``flips``.
        start_from: Number of joint steps skipped at the start of each epoch
            (after the step list has been reversed).
        is_train: Run the training/validation phase before testing.
        mconfig: Configuration object; the attributes read here are
            ``outter_task_loss``, ``attn_modality_specific``,
            ``mlp_modality_specific``, ``grad_clip_value``,
            ``dynamic_reweight``, ``capacity_ratios`` and ``capacity_ratio``.
        grad_clip: Enables gradient-norm clipping with
            ``mconfig.grad_clip_value``.
        schedular: Optional LR-scheduler factory, called as
            ``schedular(optim, 150)`` and stepped after every optimiser step.
        tune_gate_weight, moe_gate_weight: When ``tune_gate_weight`` is set the
            per-task gate-loss weights start from ``moe_gate_weight`` instead
            of ``0.1``.
        calc_flops, getattentionmap, args, goal_points: Accepted for interface
            compatibility; not used in this function.

    Returns:
        None.
    """
    if valid_criterions is None:
        valid_criterions = criterions
    if test_criterions is None:
        test_criterions = valid_criterions

    # Copy: the dynamic re-weighting below writes into this list and must not
    # leak into the caller's list or the shared default argument.
    gate_weight = list(moe_gate_weight) if tune_gate_weight else [0.1] * len(trains)

    # Per-task / per-modality state of the adaptive top-k schedule.
    task_topk = {i: {m: 2 for m in names} for i, names in enumerate(modalities)}
    topk_status = {i: {m: True for m in names} for i, names in enumerate(modalities)}
    valid_loss = {i: {m: [] for m in names} for i, names in enumerate(modalities)}
    topk_improve_sign = {i: {m: True for m in names} for i, names in enumerate(modalities)}

    # task -> modality -> routing statistic gathered during validation; after
    # each epoch it is collapsed to one weight per modality.
    ep_valid_routing_status = {}

    if is_train:
        optim = optimizer(model.parameters(),lr=lr,weight_decay=weight_decay)
        lr_schedular = None
        if schedular is not None:
            # NOTE(review): the second scheduler argument is hard-coded to 150.
            lr_schedular = schedular(optim, 150)
        bestacc = -float('inf')
        totally_count = 0  # epochs since the last improvement
        # NOTE(review): collected every epoch but neither returned nor saved.
        losses_record = {}
        for ep in range(epochs):
            print(task_topk)
            model.train(True)
            totalloss = []
            totals = []
            # fulltrains[step] = {str(task_id): batch}; batch k of every task
            # shares joint step k.
            fulltrains = []
            loss_record = {}
            for i in range(len(trains)):
                count = 0
                totalloss.append(0.0)
                totals.append(0)

                for j in trains[i]:
                    print("\r{}:{}".format(i, count), end=' ', flush=True)
                    if count >= len(fulltrains):
                        fulltrains.append({})
                    if is_affect[i]:
                        jj = j[0]
                        # Index with the task being loaded (``i``).
                        if isinstance(criterions[i], torch.nn.CrossEntropyLoss):
                            jj.append((j[3].squeeze(1) > 0).long())
                        else:
                            jj.append(j[3])
                        fulltrains[count][str(i)] = jj
                    else:
                        fulltrains[count][str(i)] = j
                    if i == flips:
                        j[-1] = (j[-1] + 1) % classesnum[i]
                    count += 1
                print(" ")
            # Reversed so the steps that contain every task come first and the
            # tail of the epoch trains the longer datasets alone.
            fulltrains.reverse()
            fulltrains = fulltrains[start_from:]
            for js in tqdm(fulltrains):
                # js = {task_key: [modality_1, modality_2, ..., target]}
                optim.zero_grad()
                losses = 0.0
                if mconfig.outter_task_loss:
                    task_spc_dist = {}
                for ii in js:
                    t = int(ii)
                    model.to_logits = model.to_logitslist[t]
                    indict = _build_inputs(js[ii], modalities[t], unsqueezing[t], transpose[t], device)
                    for mod in indict:
                        indict[mod].requires_grad = True
                    if encoder is not None:
                        indict = encoder[t](indict)

                    if len(ep_valid_routing_status) > 0:
                        model.to_logits.set_modality_weight(ep_valid_routing_status[t])

                    # The model receives the string task key here, as before.
                    out, sm_out = model(indict, task_id = ii, modality_topk = task_topk[t])
                    target = js[ii][-1]
                    loss = _task_loss(criterions[t], out, target, is_classification[t], device)
                    # Gate (load-balancing) loss of the MoE layers.
                    if mconfig.attn_modality_specific or mconfig.mlp_modality_specific:
                        g_loss = model.gate_loss(ii, modalities[t])
                    else:
                        g_loss = model.gate_loss(ii)
                    if mconfig.outter_task_loss:
                        task_spc_dist[t] = model.gate_logits(t)
                    losses += loss * train_weights[t] + g_loss * gate_weight[t]

                    # Auxiliary per-modality losses use their own name so the
                    # fused-output ``loss`` logged below is not overwritten.
                    for m_name in sm_out:
                        losses += _task_loss(criterions[t], sm_out[m_name], target, is_classification[t], device)

                    total = len(js[ii][0])
                    totals[t] += total
                    totalloss[t] += loss.item() * total
                    loss_record['task.{}.train.loss'.format(str(ii))] = loss.cpu().item()
                    loss_record['task.{}.train.{}.loss'.format(str(ii), 'gate')] = g_loss.cpu().item()

                losses.backward()
                # Clip after backward(): before it there are no gradients to clip.
                if is_hci or grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), mconfig.grad_clip_value)
                optim.step()
                optim.zero_grad()
                if lr_schedular is not None:
                    lr_schedular.step()
            # Keep the last step's losses as this epoch's training record.
            for key in loss_record:
                losses_record.setdefault(key, []).append(loss_record[key])
            for ii in range(len(trains)):
                print("epoch "+str(ep)+" train loss dataset " +str(ii)+": "+str(totalloss[ii]/totals[ii]))

            # ---- validation ----
            ep_valid_loss = {}

            with torch.no_grad():
                model.eval()
                accs = 0.0
                valid_record = {}
                for ii in range(len(valid)):
                    totalloss = 0.0
                    totals = 0
                    corrects = 0
                    # Fresh per task so the sklearn check below never appends
                    # to an already concatenated array.
                    true = []
                    pred = []
                    for jj in valid[ii]:
                        j = jj
                        if is_affect[ii]:
                            j = jj[0]
                            # The label lives in the outer batch and is
                            # appended to the modality list.
                            if isinstance(valid_criterions[ii], torch.nn.CrossEntropyLoss):
                                j.append((jj[3].squeeze(1) > 0).long())
                            else:
                                j.append(jj[3])
                        model.to_logits = model.to_logitslist[ii]
                        indict = _build_inputs(j, modalities[ii], unsqueezing[ii], transpose[ii], device)
                        if encoder is not None:
                            indict = encoder[ii](indict)

                        out, sm_out = model(indict, task_id = ii, modality_topk = task_topk[ii])

                        loss = _task_loss(valid_criterions[ii], out, j[-1], is_classification[ii], device)
                        if is_hci:
                            pred.append(torch.argmax(out, 1))
                            true.append(j[-1])

                        for m_name in sm_out:
                            m_loss = _task_loss(valid_criterions[ii], sm_out[m_name], j[-1], is_classification[ii], device)
                            ep_valid_loss.setdefault(ii, {}).setdefault(m_name, []).append(m_loss.item() * len(j[0]))

                        router_logits = model.gate_logits(ii)
                        if ii not in ep_valid_routing_status:
                            ep_valid_routing_status[ii] = {}
                        # NOTE(review): reset on every batch, so only the last
                        # validation batch feeds the routing weights -- confirm
                        # whether a per-epoch reset was intended.
                        for mm in modalities[ii]:
                            ep_valid_routing_status[ii][mm] = []
                        # Router columns are read in sorted modality-name order.
                        for i, mm in enumerate(sorted(modalities[ii])):
                            ep_valid_routing_status[ii][mm].append(router_logits[0][i].cpu().item())

                        totalloss += loss.item() * len(j[0])
                        if is_classification[ii]:
                            hit, seen = _count_correct(out, j[-1])
                            corrects += hit
                            totals += seen
                        else:
                            # NOTE(review): counts batches while ``totalloss``
                            # is weighted by batch size -- kept as in the
                            # original.
                            totals += 1
                    valid_record['task.{}.training_weight'.format(str(ii))] = train_weights[ii]
                    if is_classification[ii]:
                        acc = float(corrects)/totals

                        if is_hci:
                            print(accuracy_score(torch.cat(true, 0).cpu().numpy(),
                                                 torch.cat(pred, 0).cpu().numpy()))

                        if eval_weights is None:
                            accs += acc
                        else:
                            accs += acc * eval_weights[ii]
                        print("epoch "+str(ep)+" valid loss dataset"+str(ii)+": "+str(totalloss/totals)+" acc: "+str(acc))
                        valid_record['task.{}.valid.loss'.format(str(ii))] = totalloss/totals
                        valid_record['task.{}.valid.acc'.format(str(ii))] = acc

                    else:
                        mean_loss = totalloss/totals
                        # Lower loss is better, so it is subtracted from the score.
                        if eval_weights is None:
                            accs -= mean_loss
                        else:
                            accs -= mean_loss * eval_weights[ii]
                        print("epoch "+str(ep)+" valid loss dataset"+str(ii)+": "+str(totalloss/totals))
                        valid_record['task.{}.valid.loss'.format(str(ii))] = totalloss/totals

                    if mconfig.dynamic_reweight:
                        # Sigmoid decay of the gate-loss weight, centred on the
                        # middle epoch.
                        gate_weight[ii] = 1/(1 + np.exp(0.5 * (ep-0.5 * epochs)))
                        valid_record['task.{}.gating.weight'.format(str(ii))] = gate_weight[ii]

                if accs > bestacc:
                    print("save best")
                    bestacc = accs
                    totally_count = 0
                    torch.save(model.state_dict(),savedir)
                else:
                    totally_count += 1
            model.train(True)

            # ---- adaptive top-k per (task, modality) ----
            for ti in ep_valid_loss:
                for mm in ep_valid_loss[ti]:
                    mean_loss = np.mean(ep_valid_loss[ti][mm])
                    if len(valid_loss[ti][mm]) == 0:
                        valid_loss[ti][mm].append(mean_loss)
                    else:
                        if not topk_improve_sign[ti][mm]:
                            continue
                        # NOTE(review): ``ep // 5 == 0`` holds only for epochs
                        # 0-4; ``ep % 5 == 0`` may have been intended.
                        if valid_loss[ti][mm][-1] < mean_loss and ep // 5 == 0:
                            if not topk_status[ti][mm]:
                                # Already grown once: undo it and freeze.
                                topk_improve_sign[ti][mm] = False
                                task_topk[ti][mm] -= 1
                            else:
                                topk_status[ti][mm] = False
                                task_topk[ti][mm] += 1
                        else:
                            if mean_loss < np.min(valid_loss[ti][mm]):
                                topk_improve_sign[ti][mm] = True
                        valid_loss[ti][mm].append(mean_loss)

            # ---- collapse routing statistics to one weight per modality ----
            for ti in ep_valid_routing_status:
                inner_task_weight = [np.mean(ep_valid_routing_status[ti][mm])
                                     for mm in ep_valid_routing_status[ti]]
                inner_task_weight = torch.tensor(inner_task_weight) / torch.tensor(inner_task_weight).sum()
                for m_idx, mm in enumerate(ep_valid_routing_status[ti]):
                    # Stored as one minus the normalised routing share.
                    ep_valid_routing_status[ti][mm] = 1-inner_task_weight[m_idx].item()

            for key in valid_record:
                losses_record.setdefault(key, []).append(valid_record[key])

    # ---- test with the checkpoint stored at ``savedir`` ----
    model.to_logits = model.to_logitslist[-1]
    model.load_state_dict(torch.load(savedir))
    model.eval()

    # Switch on expert-count collection in every sub-module that supports it.
    for _, m in model.named_modules():
        if hasattr(m, 'args'):
            if hasattr(m.args, 'load_expert_count'):
                m.args.load_expert_count = True
                m.args.capacity_ratios = mconfig.capacity_ratios
                m.args.capacity_ratio = mconfig.capacity_ratio

    with torch.no_grad():
        for ii in range(len(test)):
            model.to_logits = model.to_logitslist[ii]
            totalloss = 0.0
            totals = 0
            corrects = 0
            for jj in test[ii]:
                j = jj
                if is_affect[ii]:
                    j = jj[0]
                    j.append((jj[3].squeeze(1) > 0).long())
                indict = _build_inputs(j, modalities[ii], unsqueezing[ii], transpose[ii], device)
                if encoder is not None:
                    indict = encoder[ii](indict)
                if len(ep_valid_routing_status) > 0:
                    model.to_logits.set_modality_weight(ep_valid_routing_status[ii])

                out, sm_out = model(indict, task_id = ii, modality_topk = task_topk[ii])

                if is_classification[ii]:
                    hit, seen = _count_correct(out, j[-1])
                    corrects += hit
                    totals += seen
                else:
                    loss = test_criterions[ii](out, j[-1].to(device))
                    totalloss += loss.item()
                    totals += 1
            if is_classification[ii]:
                acc = float(corrects)/totals
                print("test acc dataset "+str(ii)+": "+str(ii)+" "+str(acc))
            else:
                print("test loss dataset "+str(ii)+": "+str(totalloss/totals))

def annealing_result(min_eta, max_eta, index, temp):
    """Return the cosine-annealed value between ``max_eta`` and ``min_eta``.

    The result equals ``max_eta`` at ``index == 0`` and reaches ``min_eta``
    when ``index == temp`` (half a cosine period).
    """
    phase = index * math.pi / temp
    span = max_eta - min_eta
    return min_eta + 0.5 * span * (1 + math.cos(phase))

def multi_cycle(metrics: list, d_length, goal_point):
    """Tell whether the recent metric average has come close to ``goal_point``.

    Averages the last ``d_length`` entries of ``metrics`` and returns ``True``
    when ``goal_point`` is positive and that mean is above 90% of it;
    a non-positive ``goal_point`` always yields ``False``.
    """
    recent = np.array(metrics[-d_length:])
    mean_value = recent.mean()
    if not goal_point > 0:
        return False
    return not (mean_value / goal_point <= 0.9)

def outter_task_loss(gate_outputs):
    """Sum the pairwise MSE between tasks' softmaxed gate outputs, layer by layer.

    ``gate_outputs`` maps ``task_id -> layer_id -> tensor``.  For every layer
    each ordered pair of distinct tasks contributes the MSE between their
    softmax distributions (softmax over dim 1 after a leading axis is added),
    divided by the number of other tasks in that layer.  Returns the integer
    ``0`` when no layer holds two tasks.
    """
    # Regroup as layer -> task -> gate tensor.
    per_layer = {}
    for task, layers in gate_outputs.items():
        for layer, logits in layers.items():
            per_layer.setdefault(layer, {})[task] = logits

    total = 0
    for tasks in per_layer.values():
        others_count = len(tasks) - 1
        for task, logits in tasks.items():
            anchor = logits.unsqueeze(0)
            for other, other_logits in tasks.items():
                if other != task:
                    pair = F.mse_loss(anchor.softmax(dim=1),
                                      other_logits.unsqueeze(0).softmax(dim=1))
                    total += pair / others_count
    return total

def mis_peak_reweight(fix_portion:float, fix_weight:list, index: list, lr, temp: list, multi, metris: list = None,goal_points = None):
    """Scale each fixed weight by a per-entry cosine-annealing factor.

    For every position the step counter ``index[i]`` is advanced in place and
    the weight ``fix_weight[i] * fix_portion`` is multiplied by the annealed
    value between ``lr`` and ``0.01 * lr`` (normalised by ``lr``) with period
    ``temp[i]``.  When a counter completes its period it is reset to zero and,
    unless ``multi_cycle`` reports the last five ``metris[i]`` values as close
    to ``goal_points[i]``, the period ``temp[i]`` is doubled in place.

    ``multi`` is accepted but not used.  Returns the list of new weights.
    """
    floor_lr = lr * 0.01
    scaled = []
    for pos, base in enumerate(fix_weight):
        index[pos] += 1
        factor = annealing_result(floor_lr, lr, index[pos], temp[pos]) / lr
        scaled.append(base * fix_portion * factor)

        period_done = index[pos] != 0 and index[pos] % temp[pos] == 0
        if not period_done:
            continue
        if not multi_cycle(metris[pos], 5, goal_points[pos]):
            temp[pos] = temp[pos] * 2
        index[pos] = 0
    return scaled

def dynamic_weight_gather(old_record:list, dynamic_count:list, cur_record:list, fix_portion:float, fix_weight:list):
    """Decay each weight exponentially with its count of improving steps.

    For every position, ``dynamic_count[i]`` is incremented (capped at 10)
    when the current record is lower than the previous one; otherwise it is
    only clamped to be non-negative.  The new weight is
    ``fix_weight[i] * exp(-dynamic_count[i] / 2) * fix_portion``.
    ``old_record`` and ``dynamic_count`` are updated in place.
    """
    weights = []
    for pos, previous in enumerate(old_record):
        current = cur_record[pos]
        count = dynamic_count[pos]
        if previous > current:
            count += 1
            if not count <= 10:
                count = 10
        elif not count >= 0:
            count = 0
        dynamic_count[pos] = count

        decay = float(np.exp(-1. * count / 2))
        weights.append(fix_weight[pos] * decay * fix_portion)
        old_record[pos] = current
    return weights

def gather_route_info(one_step, all_steps):
    """Accumulate one step's routing counts into ``all_steps`` in place.

    Both arguments are nested as
    ``task_id -> modality -> layer -> layer_name -> counts``.  A leaf seen for
    the first time is stored as a copy, so the in-place ``+=`` applied on
    later steps never modifies the data still held by ``one_step``; leaves
    already present are incremented with ``+=``.

    Args:
        one_step: Routing counts of a single step.
        all_steps: Running totals, updated in place.

    Raises:
        TypeError: If ``all_steps`` is not a ``dict``.
    """
    # The previous ``assert all_steps is not dict`` compared against the type
    # object itself and therefore never failed.
    if not isinstance(all_steps, dict):
        raise TypeError("all_steps must be dict type")
    for task_id, step_modalities in one_step.items():
        task_totals = all_steps.setdefault(task_id, {})
        for modality, step_layers in step_modalities.items():
            modality_totals = task_totals.setdefault(modality, {})
            for layer, step_counts in step_layers.items():
                layer_totals = modality_totals.setdefault(layer, {})
                for spc_layer_name, counts in step_counts.items():
                    if spc_layer_name in layer_totals:
                        layer_totals[spc_layer_name] += counts
                    else:
                        layer_totals[spc_layer_name] = deepcopy(counts)

import copy
def draw_expert_count(all_steps, modalities, root_path, seed):
    """Plot expert hit rates per modality and, stacked by task, per MoE layer.

    For every ``task -> modality -> layer -> layer_name`` leaf of
    ``all_steps`` a bar chart of the normalised counts (in percent) is saved
    to ``root_path``.  The counts are also summed over the modalities of a
    task, normalised per task, and drawn as one stacked bar chart per layer
    with one segment per task.  ``draw_stacking_bar`` is called once per task.

    Args:
        all_steps: Nested routing counts as produced by ``gather_route_info``;
            leaves must support ``sum()``, ``len()`` and element-wise
            arithmetic (presumably numpy arrays or tensors -- TODO confirm).
        modalities: Per-task list of modality names.  Its length (2, 3 or 4)
            selects the display-name tables; for any other length, and for
            names missing from the tables, the raw names are used.
        root_path: Directory the images are written to.
        seed: Value embedded in the file names.
    """
    # Display names depend on which benchmark combination is being trained,
    # which is identified by the number of tasks.
    modality_title_remap = {}
    task_remap = {}
    if len(modalities) == 2:
        modality_title_remap = {
            'image': 'V&T Image:',
            'force': 'V&T Force:',
            'proprio': 'V&T Proprio:',
            'depth': 'V&T Depth:',
            'action': 'V&T Action:',
            'timeseries_gripper_pos': 'PUSH Proprioception:',
            'timeseries_gripper_sensors': 'PUSH Force:',
            'colorlessimage_timeseries': 'PUSH Image:',
            'timeseries_control': 'PUSH Control:'
        }
        task_remap = {
            0: 'PUSH',
            1: 'V&T',
        }
    elif len(modalities) == 3:
        modality_title_remap = {
            'image_1': 'ENRICO Image:',
            'image_2': 'ENRICO Set:',
            'image': 'AV-MNIST Image:',
            'audio': 'AV-MNIST Audio:',
            'pose': 'PUSH Proprioception:',
            'sensor': 'PUSH Force:',
            'trajectory': 'PUSH Image:',
            'control': 'PUSH Control:'
        }
        task_remap = {
            0: 'ENRICO',
            1: 'AV-MNIST',
            2: 'PUSH'
        }
    elif len(modalities) == 4:
        modality_title_remap = {
            'colorlessimage': "AV-MNIST Image:",
            'audiospec': 'AV-MNIST Audio:',
            'feature1': 'MOSEI Image:',
            'feature2': 'MOSEI Audio:',
            'feature3': 'MOSEI/UR-FUNNY Text:',
            'feature4': 'UR-FUNNY Image:',
            'feature5': 'UR-FUNNY Audio:',
            'static': 'MIMIC Static:',
            'timeseries': 'MIMIC Timeseries:'
        }
        task_remap = {
            0: 'AV-MNIST',
            1: 'MOSEI',
            2: 'UR-FUNNY',
            3: 'MIMIC'
        }
    layer_title_remap = {
        'attn_expert_count_k': 'Key MoE Layer',
        'attn_expert_count_q': 'Query MoE Layer',
        'attn_expert_count_v': 'Value MoE Layer',
        'mlp_expert_count': 'FNN MoE Layer'
    }

    # layer_index -> task id -> per-task hit rate (percent) summed over modalities
    layer_gather = {}
    # layer_index -> title of the cross-task stacked chart.  Filled while the
    # layer and layer type are still known separately; rebuilding the title
    # from ``layer_index`` afterwards produced keys that never matched.
    stacking_titles = {}
    for task_id in all_steps:
        inner_task = {}
        for modality in all_steps[task_id]:
            modality_name = modalities[int(task_id)][int(modality)]
            modality_title = modality_title_remap.get(modality_name, modality_name)
            for layer in all_steps[task_id][modality]:
                for spc_layer_name in all_steps[task_id][modality][layer]:
                    data = all_steps[task_id][modality][layer][spc_layer_name]
                    print(data.sum())
                    # The small epsilon guards against an all-zero count vector.
                    draw_data = data / (data.sum() + 1e-5) * 100
                    data_x = list(range(len(data)))
                    img_name = "t{}_[{}]_l{}_{}_seed[{}].png".format(task_id, modality_name, layer, spc_layer_name, seed)
                    img_path = os.path.join(root_path, img_name)
                    layer_index = '{}_{}{}'.format(layer, spc_layer_name, seed)
                    layer_title = layer_title_remap.get(spc_layer_name, spc_layer_name)

                    if layer_index not in inner_task:
                        # Copy so the in-place accumulation below leaves
                        # ``all_steps`` untouched.
                        inner_task[layer_index] = copy.deepcopy(data)
                    else:
                        inner_task[layer_index] += data
                    layer_gather.setdefault(layer_index, {})
                    stacking_titles[layer_index] = '{} - {}'.format(layer_title, layer)

                    fig = plt.figure()
                    plt.bar(data_x, draw_data)
                    plt.ylabel("Expert Hit Rate (%)")
                    plt.title('{} {}'.format(modality_title, layer_title))
                    plt.xlabel("Expert id")
                    fig.savefig(img_path)
                    plt.close()
        for li in inner_task:
            share = inner_task[li] / (inner_task[li].sum() + 1e-5) * 100
            if int(task_id) not in layer_gather[li]:
                layer_gather[li][int(task_id)] = share
            else:
                layer_gather[li][int(task_id)] += share
        draw_stacking_bar(all_steps[task_id], task_id, root_path, seed, modalities)

    # One stacked chart per layer: each task contributes a segment whose size
    # is its share of the layer's total over all tasks.
    for layer_index in layer_gather:
        img_name = "l{}_stacking.png".format(layer_index)
        img_path = os.path.join(root_path, img_name)
        index_sum = 0
        for task_idx in layer_gather[layer_index]:
            index_sum += layer_gather[layer_index][task_idx].sum()

        fig = plt.figure()
        bottom_data = None
        for task_idx in layer_gather[layer_index]:
            data_y = layer_gather[layer_index][task_idx] / (index_sum + 1e-5) * 100
            data_x = list(range(len(data_y)))
            label = '{}'.format(task_remap.get(int(task_idx), int(task_idx)))
            if bottom_data is None:
                plt.bar(data_x, data_y, label=label)
                bottom_data = data_y
            else:
                plt.bar(data_x, data_y, bottom=bottom_data, label=label)
                bottom_data = bottom_data + data_y
        plt.title('{}'.format(stacking_titles[layer_index]))
        plt.ylabel("Expert Hit Rate (%)")
        plt.xlabel("Expert id")
        plt.legend()
        fig.savefig(img_path)
        plt.close()

def draw_loss_curve(data, root_path, seed):
    """Save one line plot per recorded series, plus a joint training-weight plot.

    ``data`` maps a dotted series name to a list of values.  Every series
    whose name does not contain ``training_weight`` is drawn to its own
    ``<name with dots as underscores>_<seed>.jpg`` under ``root_path``; the
    y axis is labelled ``loss`` when the name contains ``loss`` and ``acc``
    otherwise.  All ``training_weight`` series share a single
    ``training_weight_<seed>.jpg`` figure with a legend.
    """
    weight_keys = []
    for key, series in data.items():
        if 'training_weight' in key:
            # Drawn together after the loop.
            weight_keys.append(key)
            continue
        flat_key = key.replace('.', '_')
        out_file = os.path.join(root_path, '{}_{}.jpg'.format(flat_key, seed))
        steps = list(range(len(series)))
        axis_name = 'loss' if 'loss' in key else 'acc'

        fig = plt.figure()
        plt.plot(steps, series, label = '{}'.format(flat_key))
        plt.ylabel("{}".format(axis_name))
        plt.xlabel("epoch")
        fig.savefig(out_file)
        plt.close()

    if not weight_keys:
        return

    out_file = os.path.join(root_path, '{}_{}.jpg'.format('training_weight', seed))
    # The x axis is taken from the first weight series.
    steps = list(range(len(data[weight_keys[0]])))
    fig = plt.figure()
    for key in weight_keys:
        plt.plot(steps, data[key], label = '{}'.format(key.replace('.', '_')))
    plt.ylabel("{}".format('training_weight value'))
    plt.xlabel("epoch")
    plt.legend()
    fig.savefig(out_file)
    plt.close()

def draw_stacking_bar(data, task_id, root_path, seed, modalities):
    """Draw, per MoE layer, a stacked bar chart of expert hit rates by modality.

    Args:
        data: nested mapping ``data[modality][layer][counter_name]`` holding a
            per-expert count array (anything supporting ``.sum()``, ``len()``
            and arithmetic -- presumably a numpy array; verify against caller).
        task_id: index of the task being plotted; used for the title and to
            look up modality names in ``modalities``.
        root_path: directory the ``.png`` files are written into.
        seed: value embedded in every output file name.
        modalities: per-task lists of modality names. Its length selects which
            set of display titles is used (2, 3 or 4 tasks).
    """
    task_titles_by_count = {
        2: {
            0: 'PUSH',
            1: 'V&T',
        },
        3: {
            0: 'ENRICO',
            1: 'AV-MNIST',
            2: 'PUSH',
        },
        4: {
            0: 'AV-MNIST',
            1: 'MOSEI',
            2: 'UR-FUNNY',
            3: 'MIMIC',
        },
    }
    modality_titles_by_count = {
        2: {
            'image': 'V&T Image:',
            'force': 'V&T Force:',
            'proprio': 'V&T Proprio:',
            'depth': 'V&T Depth:',
            'action': 'V&T Action:',
            'timeseries_gripper_pos': 'PUSH Proprioception:',
            'timeseries_gripper_sensors': 'PUSH Force:',
            'colorlessimage_timeseries': 'PUSH Image:',
            'timeseries_control': 'PUSH Control:',
        },
        3: {
            'image_1': 'ENRICO Image:',
            'image_2': 'ENRICO Set:',
            'image': 'AV-MNIST Image:',
            'audio': 'AV-MNIST Audio:',
            'pose': 'PUSH Proprioception:',
            'sensor': 'PUSH Force:',
            'trajectory': 'PUSH Image:',
            'control': 'PUSH Control:',
        },
        4: {
            'colorlessimage': "AV-MNIST Image:",
            'audiospec': 'AV-MNIST Audio:',
            'feature1': 'MOSEI Image:',
            'feature2': 'MOSEI Audio:',
            'feature3': 'MOSEI/UR-FUNNY Text:',
            'feature4': 'UR-FUNNY Image:',
            'feature5': 'UR-FUNNY Audio:',
            'static': 'MIMIC Static:',
            'timeseries': 'MIMIC Timeseries:',
        },
    }
    task_count = len(modalities)
    # For any other task count the two tables stay unbound, so the first
    # lookup while plotting raises NameError.
    if task_count in task_titles_by_count:
        task_title_remap = task_titles_by_count[task_count]
        modality_title_remap = modality_titles_by_count[task_count]

    # Regroup the counts as stacking_data[layer counter][modality] and record a
    # human-readable title for every known counter of every layer seen.
    layer_title_remap = {}
    stacking_data = {}
    for modality, layers in data.items():
        for layer, counters in layers.items():
            layer_title_remap.update({
                f'l{layer}.layer[attn_expert_count_k]': f'Key MoE Layer - {layer}',
                f'l{layer}.layer[attn_expert_count_q]': f'Query MoE Layer - {layer}',
                f'l{layer}.layer[attn_expert_count_v]': f'Value MoE Layer - {layer}',
                f'l{layer}.layer[mlp_expert_count]': f'FNN MoE Layer - {layer}',
            })
            for spc_layer_name, counts in counters.items():
                index_name = "l{}.layer[{}]".format(layer, spc_layer_name)
                stacking_data.setdefault(index_name, {})[modality] = counts

    # One figure per layer counter.
    for index_name, per_modality in stacking_data.items():
        stacked_so_far = None
        fig = plt.figure()
        img_name = "t{}_[{}]_seed[{}]_stacking.png".format(task_id, index_name, seed)
        img_path = os.path.join(root_path, img_name)
        # Hits over all modalities; every bar is a percentage of this total.
        index_sum = sum(counts.sum() for counts in per_modality.values())
        for modality, counts in per_modality.items():
            expert_ids = list(range(len(counts)))
            hit_rate = counts / (index_sum + 1e-5) * 100
            legend = '{}'.format(modality_title_remap[modalities[int(task_id)][int(modality)]])
            if stacked_so_far is None:
                plt.bar(expert_ids, hit_rate, label=legend)
                stacked_so_far = hit_rate
            else:
                plt.bar(expert_ids, hit_rate, bottom=stacked_so_far, label=legend)
                stacked_so_far += hit_rate
            print(img_name, hit_rate.sum())
        plt.title('{}:{}'.format(task_title_remap[int(task_id)], layer_title_remap[index_name]))
        plt.ylabel("Expert Hit Rate (%)")
        plt.xlabel("Expert id")
        plt.legend()
        fig.savefig(img_path)
        plt.close()

def get_grads(test,model,modalities,device,unsqueezing,is_affect,transpose,mattersii):
    """Average per-parameter gradients of the correct-class probability for one task.

    For every batch of ``test[mattersii]`` the mean softmax probability of the
    labelled class is back-propagated. Gradients of ``model.layers`` are
    accumulated as absolute values, gradients of ``model.cross_layers`` as
    signed values; both are divided by the number of batches.

    Args:
        test: per-task iterables of batches.
        model: network exposing ``to_logitslist``, ``layers`` and
            ``cross_layers``; ``model.to_logits`` is set to the head of the
            selected task.
        modalities: per-task lists of modality names used as input dict keys.
        device: device the inputs are moved to.
        unsqueezing: per-task flags; when set a trailing dimension is added.
        is_affect: per-task flags; when set the modalities are taken from
            ``batch[0]`` and a label ``batch[3].squeeze(1) >= 0`` is appended
            to that list in place.
        transpose: per-task flags; when set (and not unsqueezing) dims 1 and 2
            are swapped.
        mattersii: index of the only task that is processed.

    Returns:
        ``(encoder_grads, cross_grads)``: dicts keyed by the stringified
        parameter position within ``model.layers`` / ``model.cross_layers``.
    """
    # The zero-lr optimizer is only used to reset gradients between batches.
    grad_resetter = torch.optim.SGD(model.parameters(), lr=0.0)
    for task_index, task_batches in enumerate(test):
        if task_index == mattersii:
            model.to_logits = model.to_logitslist[task_index]
            encoder_grads = {}
            cross_grads = {}
            count = 0
            for batch in task_batches:
                count += 1
                grad_resetter.zero_grad()
                fields = batch
                if is_affect[task_index]:
                    fields = batch[0]
                    fields.append((batch[3].squeeze(1) >= 0).long())

                indict = {}
                for position, modality_name in enumerate(modalities[task_index]):
                    tensor = fields[position].float()
                    if unsqueezing[task_index]:
                        indict[modality_name] = tensor.unsqueeze(-1).to(device)
                    elif transpose[task_index]:
                        indict[modality_name] = tensor.to(device).transpose(1, 2)
                    else:
                        indict[modality_name] = tensor.to(device)

                probabilities = torch.nn.functional.softmax(model(indict), dim=1)
                labels = fields[-1]
                picked = [
                    probabilities[row][labels[row].long().item()]
                    for row in range(len(fields[0]))
                ]
                torch.mean(torch.stack(picked)).backward()

                for position, param in enumerate(model.layers.parameters()):
                    key = str(position)
                    encoder_grads[key] = encoder_grads.get(key, 0.0) + torch.abs(param.grad.data)
                for position, param in enumerate(model.cross_layers.parameters()):
                    key = str(position)
                    cross_grads[key] = cross_grads.get(key, 0.0) + param.grad.data

            # Turn the running sums into per-batch averages.
            for accumulated in (encoder_grads, cross_grads):
                for key in accumulated:
                    accumulated[key] /= count

    return encoder_grads,cross_grads

def gradient_blending_training(model, 
              datas, 
              task_id, 
              modality_id, 
              modalities, 
              lr, 
              optimizer: torch.optim.Optimizer, 
              criterion, 
              device, 
              is_classificiasion: bool,
              encoder,
              flips,
              classesnum: int,
              grad_clip,
              epochs,
              is_affect: bool,
              weight_decay,
              unsqueezing: bool,
              transpose: bool,
              gate_weight: float,
              mconfig : MultiModalityConfig= None,):
    """Fine-tune ``model`` using a single modality of a single task.

    Each epoch first walks ``datas[task_id]`` to stage every batch, then
    trains on the staged batches in reverse order, feeding the model only the
    modality selected by ``modality_id``.

    Args:
        model: network called as ``model(indict, task_id=task_id)`` and
            exposing ``gate_loss``.
        datas: per-task iterables of batches.
        task_id: index of the task to train on.
        modality_id: position of the modality inside a batch and inside
            ``modalities[task_id]``.
        modalities: per-task lists of modality names (input dict keys).
        lr: learning rate handed to the optimizer.
        optimizer: optimizer factory, called as
            ``optimizer(model.parameters(), lr=lr, weight_decay=weight_decay)``.
        criterion: loss function.
        device: device the inputs and targets are moved to.
        is_classificiasion: (sic, name kept for keyword callers) when true the
            target is cast to long and flattened before the loss.
        encoder: optional per-task callables applied to the input dict.
        flips: task index whose labels are shifted by one class (modulo
            ``classesnum``) while batches are staged.
        classesnum: number of classes used for the label shift.
        grad_clip: when truthy, clip the gradient norm to 1.0.
        epochs: number of passes over the data.
        is_affect: when true the modalities live in ``batch[0]`` and the label
            derived from ``batch[3]`` is appended to that list in place.
        weight_decay: weight decay handed to the optimizer.
        unsqueezing: add a trailing dimension to the input.
        transpose: swap dims 1 and 2 of the input (only if not unsqueezing).
        gate_weight: multiplier of the gate loss when MoE is enabled.
        mconfig: configuration read for ``attn_modality_specific``,
            ``mlp_modality_specific``, ``mlp_use_moe`` and ``attn_use_moe``.
            Must not be None once there is at least one batch to train on.
    """
    optim = optimizer(model.parameters(), lr=lr, weight_decay=weight_decay)
    for _ in range(epochs):
        # Stage every batch of the epoch up front.
        fulltrains = []
        for count, j in enumerate(datas[task_id]):
            print("\r{}:{}".format(task_id, count), end=' ', flush=True)
            if is_affect:
                staged = j[0]
                if isinstance(criterion, torch.nn.CrossEntropyLoss):
                    staged.append((j[3].squeeze(1)>0).long())
                else:
                    staged.append(j[3])
            else:
                staged = j
            fulltrains.append(staged)
            if task_id == flips:
                # NOTE(review): this rewrites the last element of the raw
                # batch; for affect data that is not the label appended above.
                j[-1] = (j[-1] + 1) % classesnum
        print(' ')

        fulltrains.reverse()
        for batch in tqdm(fulltrains):
            optim.zero_grad()
            modality_name = modalities[task_id][modality_id]
            inputs = batch[modality_id].to(device).float()
            if unsqueezing:
                inputs = inputs.unsqueeze(-1)
            elif transpose:
                inputs = inputs.transpose(1, 2)
            inputs.requires_grad = True
            indict = {modality_name: inputs}
            if encoder is not None:
                indict = encoder[task_id](indict)

            out = model(indict, task_id = task_id)
            if is_classificiasion:
                loss = criterion(out, batch[-1].to(device).long().reshape(-1))
            else:
                loss = criterion(out, batch[-1].to(device))

            # The gate loss is always queried, but only added when MoE is on.
            if mconfig.attn_modality_specific or mconfig.mlp_modality_specific:
                g_loss = model.gate_loss(task_id, [modality_name])
            else:
                g_loss = model.gate_loss(task_id)
            if mconfig.mlp_use_moe or mconfig.attn_use_moe:
                losses = loss + g_loss * gate_weight
            else:
                losses = loss

            losses.backward()
            # Clipping has to run after backward() and before step(); right
            # after zero_grad() there are no gradients for it to act on.
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optim.step()
            


def gradient_blending_loss(model, 
              datas, 
              task_id, 
              modality_id, 
              modalities, 
              criterion, 
              device, 
              encoder,
              is_affect: bool,
              unsqueezing: bool,
              transpose: bool,
              is_classification: bool,
              eval_weight: float = None):
    """Evaluate ``model`` on one modality of one task without tracking gradients.

    The model is switched to eval mode for the pass and always put back into
    training mode afterwards, including when evaluation raises.

    Args:
        model: network called as ``model(indict, task_id=task_id)``.
        datas: per-task iterables of batches.
        task_id: index of the task to evaluate.
        modality_id: position of the modality inside a batch and inside
            ``modalities[task_id]``.
        modalities: per-task lists of modality names (input dict keys).
        criterion: loss function.
        device: device the inputs and targets are moved to.
        encoder: optional per-task callables applied to the input dict.
        is_affect: when true the modalities live in ``batch[0]`` and the label
            derived from ``batch[3]`` is appended to that list in place
            (binarised with ``> 0`` when ``criterion`` is CrossEntropyLoss).
        unsqueezing: add a trailing dimension to the input.
        transpose: swap dims 1 and 2 of the input (only if not unsqueezing).
        is_classification: cast targets to long, flatten them and also report
            accuracy.
        eval_weight: optional factor applied to the accuracy.

    Returns:
        ``(loss, accuracy)`` for classification, otherwise ``loss``. The loss
        is the sample-weighted mean of the per-batch losses.

    Raises:
        ZeroDivisionError: if ``datas[task_id]`` yields no samples.
    """
    modality_name = modalities[task_id][modality_id]
    totalloss = 0.0
    totals = 0
    corrects = 0
    with torch.no_grad():
        model.eval()
        try:
            for jj in datas[task_id]:
                j = jj
                if is_affect:
                    j = jj[0]
                    # The label is appended to the modality list `j`, because
                    # `j[-1]` is what is read as the target below.
                    if isinstance(criterion, torch.nn.CrossEntropyLoss):
                        j.append((jj[3].squeeze(1) > 0).long())
                    else:
                        j.append(jj[3])

                inputs = j[modality_id].to(device).float()
                if unsqueezing:
                    inputs = inputs.unsqueeze(-1)
                elif transpose:
                    inputs = inputs.transpose(1, 2)
                indict = {modality_name: inputs}
                if encoder is not None:
                    indict = encoder[task_id](indict)
                out = model(indict, task_id=task_id)

                batch_size = len(j[0])
                if is_classification:
                    target = j[-1].to(device).long().reshape(-1)
                    loss = criterion(out, target)
                    preds = torch.argmax(out, dim=1)
                    corrects += int((preds == target).sum().item())
                    totals += len(preds)
                else:
                    loss = criterion(out, j[-1].to(device))
                    # Count samples, not batches, so the divisor matches the
                    # batch-size weighting of `totalloss`.
                    totals += batch_size
                totalloss += loss.item() * batch_size
        finally:
            model.train(True)

    loss = totalloss / totals
    if is_classification:
        acc = float(corrects) / totals
        accs = acc if eval_weight is None else acc * eval_weight
        return loss, accs
    return loss

def offline_gradient_blending(model: nn.Module, 
                              epochs: int,
                              trains: List,
                              valid: List,
                              test: List,
                              modalities: List,
                              savedir: str,lr=0.001,
                              weight_decay=0.0, 
                              optimizer=torch.optim.Adam, 
                              criterions=[torch.nn.CrossEntropyLoss(), torch.nn.CrossEntropyLoss(), torch.nn.MSELoss()],
                              valid_criterions = None,
                              test_criterions = None,
                              unsqueezing=[True,True], 
                              device="cuda:0",
                              eval_weights = None,
                              train_weights = None,
                              is_affect=[False,False],
                              transpose=[False,False],
                              encoder = None,
                              is_classification = [True, True, False],
                              flips=-1, 
                              classesnum=[2,10,2,2],
                              mconfig : MultiModalityConfig= None,
                              grad_clip = False,
                              moe_gate_weight = [1.,1.,1.],
                              logitslist: torch.nn.ModuleList = None):
    """Score every (task, modality) pair by fine-tuning on that modality alone.

    For each modality of each task the checkpoint at ``savedir`` is reloaded,
    ``model.to_logits`` is swapped for the next head in ``logitslist``, the
    model is evaluated on train/valid, fine-tuned with
    ``gradient_blending_training`` for ``epochs`` epochs and evaluated again
    on train/valid/test. The un-normalized score of the modality is::

        valid_loss_after / max-floored(valid_loss_after - train_loss_after) ** 2

    where a non-positive gap is replaced by ``0.0001``.

    Args:
        model: network to probe; its weights are overwritten from ``savedir``
            and, if present, every entry of ``model.modalities_num`` is set
            to 1 (not restored here).
        epochs: fine-tuning epochs per modality.
        trains, valid, test: per-task iterables of batches.
        modalities: per-task lists of modality names.
        savedir: path of the checkpoint loaded before every probe.
        lr: base learning rate, scaled by ``train_weights[t]``.
        weight_decay, optimizer, grad_clip, flips, encoder, device, mconfig:
            forwarded to the training / evaluation helpers.
        criterions: per-task loss functions; used for training and for every
            evaluation.
        valid_criterions, test_criterions: accepted for signature
            compatibility; not used by this routine.
        unsqueezing, is_affect, transpose, is_classification, classesnum,
        moe_gate_weight: per-task settings forwarded to the helpers.
        eval_weights: optional per-task accuracy weights; ``None`` disables
            weighting.
        train_weights: optional per-task learning-rate multipliers; ``None``
            means 1.0 for every task.
        logitslist: one output head per (task, modality) pair, consumed in
            task-major order. Required.

    Returns:
        Nested list ``weights[t][m]`` of scores normalized so that the scores
        of each task sum to 1.
    """
    def _evaluate(datas, t, m):
        # Single place for the evaluation call so all five probes agree.
        return gradient_blending_loss(
            model=model,
            datas=datas,
            task_id=t,
            modality_id=m,
            modalities=modalities,
            criterion=criterions[t],
            device=device,
            encoder=encoder,
            is_affect=is_affect[t],
            unsqueezing=unsqueezing[t],
            transpose=transpose[t],
            is_classification=is_classification[t],
            eval_weight=None if eval_weights is None else eval_weights[t]
        )

    def _loss_of(result, t):
        # Classification evaluations return (loss, accuracy).
        return result[0] if is_classification[t] else result

    modality_index = 0
    un_normalized_weight = []
    orig_to_logits = model.to_logits
    for t in range(len(modalities)):
        un_normalized_weight.append([])
        lr_scale = 1.0 if train_weights is None else train_weights[t]
        for m in range(len(modalities[t])):
            print("Gradient blending for task {}, modality:{}, starting...".format(t, modalities[t][m]))
            # Every probe starts from the same saved weights.
            model.load_state_dict(torch.load(savedir))
            model.to_logits = logitslist[modality_index]
            modality_index += 1
            if hasattr(model, 'modalities_num'):
                for i in range(len(model.modalities_num)):
                    model.modalities_num[i] = 1

            # Losses before fine-tuning.
            result = _evaluate(trains, t, m)
            print("Evaluation results on training data before finetune:{}".format(result))
            gb_train_loss_1 = _loss_of(result, t)
            result = _evaluate(valid, t, m)
            print("Evaluation results on valid data before finetune:{}".format(result))
            gb_valid_loss_1 = _loss_of(result, t)

            gradient_blending_training(
                model = model, 
                datas=trains, 
                task_id=t,
                modality_id=m,
                modalities=modalities,
                lr=lr * lr_scale,
                optimizer=optimizer,
                criterion=criterions[t],
                device=device,
                is_classificiasion=is_classification[t],
                encoder=encoder,
                flips = flips,
                classesnum=classesnum[t],
                grad_clip=grad_clip,
                epochs=epochs,
                is_affect=is_affect[t],
                weight_decay=weight_decay,
                unsqueezing=unsqueezing[t],
                transpose=transpose[t],
                gate_weight=moe_gate_weight[t],
                mconfig=mconfig)

            # Losses after fine-tuning.
            result = _evaluate(trains, t, m)
            print("Evaluation results on training data after finetune:{}".format(result))
            gb_train_loss_2 = _loss_of(result, t)
            result = _evaluate(valid, t, m)
            print("Evaluation results on valid data after finetune:{}".format(result))
            gb_valid_loss_2 = _loss_of(result, t)

            gb_test_loss = _evaluate(test, t, m)
            print("Evaluation results on test data after finetune:{}".format(gb_test_loss))

            model.to_logits = orig_to_logits
            # Score of the current modality.
            delta_overfitting = gb_valid_loss_2 - gb_train_loss_2
            delta_general = gb_valid_loss_2
            # Floor non-positive gaps (including exactly 0) to keep the
            # division below defined.
            if delta_overfitting <= 0:
                delta_overfitting = 0.0001
            un_normalized_weight[t].append(
                delta_general / delta_overfitting ** 2
            )
            print('Final Result:', gb_train_loss_1, gb_valid_loss_1, gb_train_loss_2, gb_valid_loss_2, un_normalized_weight[t][-1])
            print("Gradient blending for task {}, modality:{}, Finished.".format(t, modalities[t][m]))
    print("Un-normalized result:")
    for t in range(len(un_normalized_weight)):
        for m in range(len(un_normalized_weight[t])):
            print("Score for task {}, modality:{} is:{}".format(t, modalities[t][m], un_normalized_weight[t][m]))

    # Normalize once over all pairs and once within each task.
    all_sum = sum(abs(w) for task_weights in un_normalized_weight for w in task_weights)
    task_sum = [sum(abs(w) for w in task_weights) for task_weights in un_normalized_weight]
    normalized_weight = [
        [abs(w) / all_sum for w in task_weights]
        for task_weights in un_normalized_weight
    ]
    normalized_weight_by_task = [
        [abs(w) / task_sum[t] for w in task_weights]
        for t, task_weights in enumerate(un_normalized_weight)
    ]

    print("Normalized result:")
    for t in range(len(normalized_weight)):
        for m in range(len(normalized_weight[t])):
            print("Score for task {}, modality:{} is:{}".format(t, modalities[t][m], normalized_weight[t][m]))
    print("Normalized by task result:")
    for t in range(len(normalized_weight_by_task)):
        for m in range(len(normalized_weight_by_task[t])):
            print("Score for task {}, modality:{} is:{}".format(t, modalities[t][m], normalized_weight_by_task[t][m]))

    return normalized_weight_by_task


def online_one_step_gb(model: nn.Module, 
                    epochs: int,
                    trains: List,
                    valid: List,
                    test: List,
                    modalities: List,
                    savedir: str,lr=0.001,
                    weight_decay=0.0, 
                    optimizer=torch.optim.Adam, 
                    criterions=[torch.nn.CrossEntropyLoss(), torch.nn.CrossEntropyLoss(), torch.nn.MSELoss()],
                    valid_criterions = None,
                    test_criterions = None,
                    unsqueezing=[True,True], 
                    device="cuda:0",
                    eval_weights = None,
                    train_weights = None,
                    is_affect=[False,False],
                    transpose=[False,False],
                    encoder = None,
                    is_classification = [True, True, False],
                    flips=-1, 
                    classesnum=[2,10,2,2],
                    mconfig : MultiModalityConfig= None,
                    grad_clip = False,
                    moe_gate_weight = [1.,1.,1.],
                    logitslist: torch.nn.ModuleList = None):
    """Run one gradient-blending probe from the current weights and apply it.

    The current weights are checkpointed to ``savedir``, the offline probe
    (``offline_gradient_blending``) is run against that checkpoint, and the
    checkpoint is loaded back so the probe's fine-tuning leaves no trace in
    the weights. ``model.modalities_num`` is restored from
    ``model.args.modalities``. When ``model.args.auto_gate_loss`` is set,
    ``model.args.gating_loss_map`` is rebuilt with one gating-loss weight per
    modality name, interpolated between 1.5 (score 0) and 0.1 (score 1).
    The map is keyed by modality name only, so a name shared by several tasks
    keeps the value of the last task that uses it.

    All parameters are forwarded unchanged to ``offline_gradient_blending``.
    """
    torch.save(model.state_dict(), savedir)
    task_sum = offline_gradient_blending(model, epochs, trains, valid, test, modalities,
                              savedir, lr, weight_decay, optimizer, criterions, valid_criterions,
                              test_criterions, unsqueezing, device, eval_weights, train_weights,
                              is_affect, transpose, encoder, is_classification,
                              flips, classesnum, mconfig=mconfig, grad_clip=grad_clip, moe_gate_weight=moe_gate_weight,
                              logitslist=logitslist)
    model.load_state_dict(torch.load(savedir))
    # The probe sets every modalities_num entry to 1 (only when the attribute
    # exists), so undo it under the same condition. `model.args.modalities` is
    # a list of per-task lists, hence range(len(...)).
    if hasattr(model, 'modalities_num'):
        for i in range(len(model.args.modalities)):
            model.modalities_num[i] = len(model.args.modalities[i])
    gl_range = [0.1, 1.5]
    if model.args.auto_gate_loss:
        model.args.gating_loss_map = {}
        modalities=model.args.modalities
        for t in range(len(modalities)):
            for m in range(len(modalities[t])):
                # Higher blending score -> smaller gating-loss weight.
                model.args.gating_loss_map[modalities[t][m]] = gl_range[1] - (gl_range[1]-gl_range[0]) * task_sum[t][m]