import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import argparse
import torch.nn as nn
from utils.util import *
from models.ps.layer_warpper import ModelUnion
import numpy as np
from scipy.interpolate import make_interp_spline


def smooth_xy(lx, ly, num_points=100):
    """Smooth a polyline by fitting an interpolating B-spline through its points.

    The spline passes through every input point. When exactly three points are
    given, one extra point is added on each side (a quarter of the x-span
    outwards, with y equal to the middle y value) so that a cubic spline can be
    fitted. The returned x range still covers only the original three points.
    Inputs with fewer than four points (after padding) fall back to a
    lower-degree spline instead of failing.

    :param lx: x-axis data; any 1-D sequence of numbers (list, tuple or numpy
        array). ``make_interp_spline`` requires it to be strictly increasing.
    :param ly: y-axis data; a sequence of the same length as ``lx``.
    :param num_points: number of evenly spaced samples in the smoothed output.
    :return: ``[x_smooth, y_smooth]``, two numpy arrays of length ``num_points``.
    """
    # Work on plain lists: ``[a] + lx`` on a numpy array would broadcast an
    # element-wise addition instead of concatenating, and fails on tuples.
    lx = list(lx)
    ly = list(ly)

    is_padded = len(lx) == 3
    if is_padded:
        x_start, x_stop = lx[0], lx[-1]
        interval = (lx[-1] - lx[0]) / 4
        lx = [lx[0] - interval] + lx + [lx[-1] + interval]
        ly = [ly[1]] + ly + [ly[1]]

    x = np.array(lx)
    y = np.array(ly)
    if not is_padded:
        x_start, x_stop = x.min(), x.max()
    x_smooth = np.linspace(x_start, x_stop, num_points)

    # Cubic when there are enough points; otherwise the highest degree the
    # number of points allows (a k-degree spline needs at least k + 1 points).
    k = min(3, len(x) - 1)
    y_smooth = make_interp_spline(x, y, k=k)(x_smooth)
    return [x_smooth, y_smooth]


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to ``True``; this converter interprets the text instead.

    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    :param argv: optional list of argument strings; ``None`` (the default)
        parses ``sys.argv[1:]`` exactly as before.
    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Arguments for the training purpose.')
    parser.add_argument('--gpuNums', type=int, default=1, help='number of gpus')
    parser.add_argument('--nEpochs', type=int, default=40, help='number of epochs to train for')
    parser.add_argument('--warmup', type=int, default=2, help='the epochs for warmup')
    parser.add_argument('--lr', type=float, default=1e-1, help='Learning Rate. Default=0.1')
    parser.add_argument('--mask_lr', type=float, default=-1.0, help='Mask Learning Rate. Default=-1.0')
    parser.add_argument('--optim', type=str, required=False, default="SGD", choices=["ADAM", "SGD", "ADAMW"], help='optimizer. Default=SGD')
    parser.add_argument('--wd', type=float, required=False, default=0.0, help='weight decay. Default=0.0')
    parser.add_argument('--momentum', type=float, required=False, default=0.9, help='momentum. Default=0.9')
    parser.add_argument('--threads', type=int, default=12, help='number of threads for data loader to use')
    parser.add_argument('--backbone', type=str, required=False, default='resnet50', choices=[
        "vit_small_patch16_224",
        "vit_base_patch16_224",
        "resnet18",
        "resnet34",
        "resnet50",
        "wide_resnet",
        "timm_resnet18",
        "timm_resnet26",
        "timm_resnet34",
        "timm_resnet50",
        "densenet121",
        "timm_densenet121",
        ], help="backbone of the model")
    parser.add_argument('--batchSize', type=int, default=96, help='training batch size')
    parser.add_argument('--resume_from', type=int, default=0, help='iteration to resume from')
    parser.add_argument('--save_path', type=str, default="chk/exp", help='path to save the model')
    parser.add_argument('--visual_file', type=str, default="", help='path to save the visual_data')
    parser.add_argument('--logname', type=str, default='ps_joint_log', help="name of the logging file")
    parser.add_argument('--chkname', type=str, default='chk/torch/resnet50-19c8e357.pth', help="path to the checkpoint file to load")
    parser.add_argument('--p', type=float, required=False, default=0.5, help='end p. Default=0.5')
    parser.add_argument('--p_T', type=int, required=False, default=10, help='the update T of p. Default=10 epochs')
    # type=bool would turn "--cropped False" into True; parse the text instead.
    parser.add_argument('--cropped', type=_str2bool, required=False, default=False, help='crop the pic or not')
    parser.add_argument('--num_iterations', type=int, required=False, default=5, help='the iteration times of the all tasks')
    parser.add_argument('--visual_layers', type=int, required=False, default=1, help='num of visualized layers')
    parser.add_argument('--visual_start', type=int, required=False, default=0, help='the id of the start visualized layers')

    # DDP settings
    parser.add_argument('--nprocs', type=int, default=1, help='number of gpus')
    parser.add_argument('--local_rank',
                        default=-1,
                        type=int,
                        help='node rank for distributed training')
    parser.add_argument('--seed',
                        default=None,
                        type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--ip', default='127.0.0.1', type=str)
    parser.add_argument('--port', default="29500", type=str)

    args = parser.parse_args(argv)
    return args

# Lookup table from color name to its "#RRGGBB" hex string. The names appear
# to be the standard CSS/X11 named colors.
# NOTE(review): not referenced by the visible code in this script; the plot
# below uses its own hard-coded `colors` list instead.
cnames = {
    'aliceblue':            '#F0F8FF',
    'antiquewhite':         '#FAEBD7',
    'aqua':                 '#00FFFF',
    'aquamarine':           '#7FFFD4',
    'azure':                '#F0FFFF',
    'beige':                '#F5F5DC',
    'bisque':               '#FFE4C4',
    'black':                '#000000',
    'blanchedalmond':       '#FFEBCD',
    'blue':                 '#0000FF',
    'blueviolet':           '#8A2BE2',
    'brown':                '#A52A2A',
    'burlywood':            '#DEB887',
    'cadetblue':            '#5F9EA0',
    'chartreuse':           '#7FFF00',
    'chocolate':            '#D2691E',
    'coral':                '#FF7F50',
    'cornflowerblue':       '#6495ED',
    'cornsilk':             '#FFF8DC',
    'crimson':              '#DC143C',
    'cyan':                 '#00FFFF',
    'darkblue':             '#00008B',
    'darkcyan':             '#008B8B',
    'darkgoldenrod':        '#B8860B',
    'darkgray':             '#A9A9A9',
    'darkgreen':            '#006400',
    'darkkhaki':            '#BDB76B',
    'darkmagenta':          '#8B008B',
    'darkolivegreen':       '#556B2F',
    'darkorange':           '#FF8C00',
    'darkorchid':           '#9932CC',
    'darkred':              '#8B0000',
    'darksalmon':           '#E9967A',
    'darkseagreen':         '#8FBC8F',
    'darkslateblue':        '#483D8B',
    'darkslategray':        '#2F4F4F',
    'darkturquoise':        '#00CED1',
    'darkviolet':           '#9400D3',
    'deeppink':             '#FF1493',
    'deepskyblue':          '#00BFFF',
    'dimgray':              '#696969',
    'dodgerblue':           '#1E90FF',
    'firebrick':            '#B22222',
    'floralwhite':          '#FFFAF0',
    'forestgreen':          '#228B22',
    'fuchsia':              '#FF00FF',
    'gainsboro':            '#DCDCDC',
    'ghostwhite':           '#F8F8FF',
    'gold':                 '#FFD700',
    'goldenrod':            '#DAA520',
    'gray':                 '#808080',
    'green':                '#008000',
    'greenyellow':          '#ADFF2F',
    'honeydew':             '#F0FFF0',
    'hotpink':              '#FF69B4',
    'indianred':            '#CD5C5C',
    'indigo':               '#4B0082',
    'ivory':                '#FFFFF0',
    'khaki':                '#F0E68C',
    'lavender':             '#E6E6FA',
    'lavenderblush':        '#FFF0F5',
    'lawngreen':            '#7CFC00',
    'lemonchiffon':         '#FFFACD',
    'lightblue':            '#ADD8E6',
    'lightcoral':           '#F08080',
    'lightcyan':            '#E0FFFF',
    'lightgoldenrodyellow': '#FAFAD2',
    'lightgreen':           '#90EE90',
    'lightgray':            '#D3D3D3',
    'lightpink':            '#FFB6C1',
    'lightsalmon':          '#FFA07A',
    'lightseagreen':        '#20B2AA',
    'lightskyblue':         '#87CEFA',
    'lightslategray':       '#778899',
    'lightsteelblue':       '#B0C4DE',
    'lightyellow':          '#FFFFE0',
    'lime':                 '#00FF00',
    'limegreen':            '#32CD32',
    'linen':                '#FAF0E6',
    'magenta':              '#FF00FF',
    'maroon':               '#800000',
    'mediumaquamarine':     '#66CDAA',
    'mediumblue':           '#0000CD',
    'mediumorchid':         '#BA55D3',
    'mediumpurple':         '#9370DB',
    'mediumseagreen':       '#3CB371',
    'mediumslateblue':      '#7B68EE',
    'mediumspringgreen':    '#00FA9A',
    'mediumturquoise':      '#48D1CC',
    'mediumvioletred':      '#C71585',
    'midnightblue':         '#191970',
    'mintcream':            '#F5FFFA',
    'mistyrose':            '#FFE4E1',
    'moccasin':             '#FFE4B5',
    'navajowhite':          '#FFDEAD',
    'navy':                 '#000080',
    'oldlace':              '#FDF5E6',
    'olive':                '#808000',
    'olivedrab':            '#6B8E23',
    'orange':               '#FFA500',
    'orangered':            '#FF4500',
    'orchid':               '#DA70D6',
    'palegoldenrod':        '#EEE8AA',
    'palegreen':            '#98FB98',
    'paleturquoise':        '#AFEEEE',
    'palevioletred':        '#DB7093',
    'papayawhip':           '#FFEFD5',
    'peachpuff':            '#FFDAB9',
    'peru':                 '#CD853F',
    'pink':                 '#FFC0CB',
    'plum':                 '#DDA0DD',
    'powderblue':           '#B0E0E6',
    'purple':               '#800080',
    'red':                  '#FF0000',
    'rosybrown':            '#BC8F8F',
    'royalblue':            '#4169E1',
    'saddlebrown':          '#8B4513',
    'salmon':               '#FA8072',
    'sandybrown':           '#FAA460',
    'seagreen':             '#2E8B57',
    'seashell':             '#FFF5EE',
    'sienna':               '#A0522D',
    'silver':               '#C0C0C0',
    'skyblue':              '#87CEEB',
    'slateblue':            '#6A5ACD',
    'slategray':            '#708090',
    'snow':                 '#FFFAFA',
    'springgreen':          '#00FF7F',
    'steelblue':            '#4682B4',
    'tan':                  '#D2B48C',
    'teal':                 '#008080',
    'thistle':              '#D8BFD8',
    'tomato':               '#FF6347',
    'turquoise':            '#40E0D0',
    'violet':               '#EE82EE',
    'wheat':                '#F5DEB3',
    'white':                '#FFFFFF',
    'whitesmoke':           '#F5F5F5',
    'yellow':               '#FFFF00',
    'yellowgreen':          '#9ACD32'
}

def _output_layer_name(backbone):
    """Return the submodule name of the classifier layer for ``backbone``.

    The check order matters: "wide_resnet" also contains "resnet", so it has to
    be tested first.

    :param backbone: backbone identifier (one of the ``--backbone`` choices).
    :return: the output layer's submodule name, or ``None`` if the backbone
        family is not recognised.
    """
    if "wide_resnet" in backbone:
        return "linears.0"
    if "resnet" in backbone:
        return "fc"
    if "vit" in backbone:
        return "head"
    if "densenet" in backbone:
        return "classifier"
    return None


def _in_channels(module):
    """Return the input width of an ``nn.Conv2d`` or ``nn.Linear`` layer.

    ``nn.Linear`` has no ``in_channels`` attribute (it exposes ``in_features``),
    so reading ``in_channels`` unconditionally fails for linear layers.
    """
    if isinstance(module, nn.Conv2d):
        return module.in_channels
    return module.in_features


def _collect_channel_sharing(population, output_layer_name, scale=12):
    """Find, for every pair of models, which output channels have identical weights.

    Every ``nn.Conv2d`` / ``nn.Linear`` layer other than ``output_layer_name``
    is compared channel by channel. A channel counts as shared when the sum of
    absolute weight differences is below 1e-5.

    :param population: object exposing ``models`` (a list of ``nn.Module`` with
        identical submodule names) and ``task_names`` (one name per model).
    :param output_layer_name: submodule name to skip (the classifier), or None.
    :param scale: horizontal spacing between channel positions in the plot.
    :return: tuple ``(res, name2id, name2inchannels, max_channels,
        module2relation2count, relation2count)``:

        - ``res[layer][model_idx]`` is a list of ``(x, y)`` plot positions, one
          per output channel. A channel shared with an earlier model reuses
          that model's position.
        - ``name2id[layer]`` is the layer's index in discovery order.
        - ``name2inchannels[layer]`` is the layer's input width.
        - ``max_channels`` is the largest output width seen.
        - ``module2relation2count[layer]["a->b"]`` is the number of shared
          channels in that layer between tasks a and b.
        - ``relation2count["a->b"]`` is that number summed over all layers.
    """
    res = dict()
    name2id = dict()
    name2inchannels = dict()
    module2relation2count = dict()
    relation2count = dict()
    max_channels = 0
    id_count = 0
    models = population.models
    # Pure weight inspection: no autograd graph is needed.
    with torch.no_grad():
        for idx, m_source in enumerate(models):
            for offset, m_target in enumerate(models[idx + 1:]):
                target_idx = idx + 1 + offset
                relation = f"{population.task_names[idx]}->{population.task_names[target_idx]}"
                for n, source_module in m_source.named_modules():
                    if type(source_module) not in (nn.Conv2d, nn.Linear) or n == output_layer_name:
                        continue
                    target_module = m_target.get_submodule(n)
                    c = source_module.weight.shape[0]
                    max_channels = max(max_channels, c)
                    if n not in res:
                        res[n] = [[(idx * c * scale + j * scale, id_count * 2 + 1) for j in range(c)]]
                        name2id[n] = id_count
                        name2inchannels[n] = _in_channels(source_module)
                        id_count += 1
                    if len(res[n]) < target_idx + 1:
                        res[n].append([(target_idx * c * scale + j * scale, name2id[n] * 2 + 1) for j in range(c)])
                    # abs() is required: without it, any channel whose
                    # differences sum to a negative value would count as shared.
                    diff = (target_module.weight.reshape(c, -1) - source_module.weight.reshape(c, -1)).abs().sum(dim=-1)
                    is_sources = diff < 1e-5
                    for c_idx, is_source in enumerate(is_sources.tolist()):
                        if is_source:
                            res[n][target_idx][c_idx] = res[n][idx][c_idx]
                    shared = int(is_sources.sum())
                    module2relation2count.setdefault(n, dict())[relation] = shared
                    relation2count[relation] = relation2count.get(relation, 0) + shared
    return res, name2id, name2inchannels, max_channels, module2relation2count, relation2count


def _plot_layers(res, name2id, name2inchannels, max_channels, visual_start, visual_layers, colors):
    """Draw the channel-connection diagram for the selected layers on the current figure.

    Layers are visited in ``res`` order. The first ``visual_start`` layers are
    skipped, then ``visual_layers`` layers are drawn. For every model and
    output channel, a smoothed curve runs from each input position through
    the channel's position. The channel's position is marked with a diamond.

    :return: the 1-based count of the last layer visited (used in the file name).
    """
    layer_count = 0
    for k, v in res.items():
        layer_count += 1
        if layer_count <= visual_start:
            continue
        print(f"processing {layer_count}'th/{len(name2inchannels)} layer: {k}")
        in_channels = name2inchannels[k]
        layer_id = name2id[k]
        y = [layer_id * 2, layer_id * 2 + 1, layer_id * 2 + 2]
        for m_idx, m_poses in enumerate(v):
            c = len(m_poses)
            # Cycle through the palette so more models than colors cannot IndexError.
            color = colors[m_idx % len(colors)]
            # Loop-invariant horizontal shift of this model's column.
            x_shift = (m_idx - len(v) // 2 + 2) * max_channels
            for j, (x1, _y1) in enumerate(m_poses):
                mid_x = x1 + 4 * m_idx
                out_x = m_idx * max_channels + (max_channels // c) * j + x_shift
                for in_c in range(in_channels):
                    x = [
                        m_idx * max_channels + (max_channels // in_channels) * in_c + x_shift,
                        mid_x,
                        out_x,
                    ]
                    sy, sx = smooth_xy(y, x)
                    plt.plot(sx, sy, color=color, linewidth=0.75 / in_channels, zorder=1)
                # One marker per channel; it used to be redrawn once per input channel.
                if in_channels > 0:
                    plt.scatter(mid_x, y[1], color=color, marker="D", zorder=2)
        if layer_count - visual_start == visual_layers:
            break
    return layer_count


if __name__ == '__main__':
    import os

    args = get_args()
    args.num_classes = 1000
    root_model = build_model(args)
    output_layer_name = _output_layer_name(args.backbone)
    task_names = [
        "cubs_cropped",
        "flowers",
        "sketches",
        "stanford_cars_cropped",
        "wikiart",
    ]
    num_classes = [200, 102, 250, 196, 195]
    ps_load_state_dict(root_model, args.chkname, prefix=output_layer_name)
    population = ModelUnion(root_model, [None for _ in task_names], [None for _ in task_names], num_classes, task_names)
    if args.resume_from > 0:
        population.load_models(args.save_path + f"_{args.resume_from - 1}", prefix=output_layer_name)

    (res, name2id, name2inchannels, max_channels,
     module2relation2count, relation2count) = _collect_channel_sharing(population, output_layer_name)
    print(relation2count)
    if args.visual_layers == -1:
        args.visual_layers = len(name2inchannels)
    print(name2inchannels)

    plt.figure(figsize=(25, 10 * args.visual_layers), dpi=100)
    plt.axis('off')
    colors = [
        "#536FC6",
        "#90CC74",
        "#FAC758",
        "#ED6666",
        "#73BFDE",
        "#3BA272",
    ]
    layer_count = _plot_layers(
        res, name2id, name2inchannels, max_channels,
        args.visual_start, args.visual_layers, colors,
    )
    # savefig does not create missing directories.
    os.makedirs('./visual', exist_ok=True)
    plt.savefig(f'./visual/model_{layer_count}.png', transparent=True, dpi=200, bbox_inches='tight')
    plt.show()
