#!/usr/bin/env python
# coding: utf-8



import json
from pprint import pprint

import torch
import time
import random
import numpy as np
import torchvision

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

import torchvision.models as models
import torchvision.transforms as transforms

import timm

from scipy.spatial.distance import cdist
import copy

# Use the CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)




def test_model(model, test_loader, end_id=-1):
    """Evaluate top-1 classification accuracy of ``model`` over ``test_loader``.

    Each batch is moved to the module-level ``device``. The model itself is
    not moved, so the caller must place it on that device. On every 50th
    batch a progress line is printed with the correct count, sample count,
    running accuracy, and seconds since the previous report.

    Args:
        model: callable returning per-class scores for a batch. The
            prediction is the argmax over dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
        end_id: if not -1, stop after the batch with this index has been
            counted (i.e. ``end_id + 1`` batches are evaluated).

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: if the loader yielded no samples, so accuracy is
            undefined. Previously this surfaced as ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        tic = time.time()
        for batch_idx, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            seen += labels.size(0)

            if end_id != -1 and batch_idx == end_id:
                break
            if batch_idx % 50 == 0:
                toc = time.time()
                print(correct, seen, correct / seen, toc - tic)
                tic = time.time()

    if seen == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return (correct / seen) * 100
    
    

def forward_layers(Qs, lambdas, num_Qs):
    """Return the ``lambdas``-weighted sum of raw pairwise distance matrices.

    For each ``q`` in ``range(num_Qs)``, ``Qs[q][0]`` and ``Qs[q][1]`` are two
    point arrays. Their pairwise Euclidean distance matrix (``cdist`` default
    metric) is computed, and the matrices are summed with weights
    ``lambdas``. Unlike ``recover_weights``, the distances are not shifted or
    scaled.

    Args:
        Qs: indexable of point-array pairs; all pairs must yield
            equally-shaped distance matrices.
        lambdas: array with ``num_Qs`` elements.
        num_Qs: number of pairs to combine.

    Returns:
        The weighted sum of the distance matrices.
    """
    per_set = []
    for q in range(num_Qs):
        src, dst = Qs[q][0], Qs[q][1]
        per_set.append(cdist(src, dst))
    weights = lambdas.reshape(num_Qs, 1, 1)
    return (np.array(per_set) * weights).sum(axis=0)

def rel_position(cur_q, nex_q, Q_dim):
    """Return all pairwise difference vectors between two point sets.

    Args:
        cur_q: array of shape ``(n, Q_dim)``.
        nex_q: array of shape ``(m, Q_dim)``.
        Q_dim: dimensionality of each point. Rows of ``cur_q`` are reshaped
            with it, so a mismatch raises ``ValueError`` from ``reshape``.

    Returns:
        Array of shape ``(n, m, Q_dim)`` whose ``[i, j]`` entry is
        ``cur_q[i] - nex_q[j]``. An empty ``cur_q`` gives shape
        ``(0, m, Q_dim)``.
    """
    cur_q = np.asarray(cur_q)
    # Broadcasting (n, 1, d) against (m, d) builds all n*m differences in a
    # single vectorized operation. This replaces a Python loop over rows;
    # the function is called on every training step.
    return cur_q.reshape(cur_q.shape[0], 1, Q_dim) - np.asarray(nex_q)

def update_params_rigid(cur_coords, nex_coords, lambdas, t_mat, Q_dim, num_Qs, step_size=0.002):
    """Apply one in-place update step to coordinate sets and mixing weights.

    The current reconstruction is
    ``sum_q lambdas[q] * 5 * (cdist(cur_coords[q], nex_coords[q]) - 0.5)``.
    The residual is that reconstruction minus ``t_mat``. For each set ``q``,
    points are moved along their pairwise difference vectors weighted by the
    residual. ``lambdas[q]`` is decreased by
    ``step_size * sum(D_q * residual)``. All sets use the same residual,
    computed once at the start of the call.

    NOTE(review): the coordinate step uses the raw difference vectors
    ``x_i - y_j``. It is not the exact gradient of ``5 * (||x_i - y_j|| - 0.5)``,
    which would divide by the distance and carry the factor 5.

    Args:
        cur_coords: array indexed ``[set, row_point, Q_dim]``; modified in place.
        nex_coords: array indexed ``[set, col_point, Q_dim]``; modified in place.
        lambdas: array with ``num_Qs`` elements (e.g. shape ``(num_Qs, 1)``);
            modified in place.
        t_mat: target matrix with ``rows * cols`` entries.
        Q_dim: dimensionality of each point.
        num_Qs: number of coordinate sets.
        step_size: step scale applied to both coordinates and lambdas.

    Returns:
        ``(cur_coords, nex_coords, lambdas, sum(|residual|))``. The residual
        is measured before this update.
    """
    Q_rol = cur_coords.shape[1]
    Q_col = nex_coords.shape[1]

    rel_dist_list = np.array([5*(cdist(cur_coords[set_id], nex_coords[set_id])-0.5) for set_id in range(num_Qs)])
    rel_vect_list = np.array([rel_position(cur_coords[set_id], nex_coords[set_id], Q_dim) for set_id in range(num_Qs)])
    res_mat = np.sum(rel_dist_list*lambdas.reshape(num_Qs,1,1), axis=0)
    res_error = res_mat - t_mat

    # The residual does not depend on the set index. Build its
    # (rows, cols, 1) view once and let broadcasting expand it across Q_dim.
    # This replaces an np.repeat copy on every iteration; the values, and
    # therefore the results, are identical.
    error_3d = res_error.reshape(Q_rol, Q_col, 1)

    for Q_id in range(num_Qs):

        cur_lambda = lambdas[Q_id]
        rel_dist_M = step_size*rel_dist_list[Q_id]
        rel_vect_M = step_size*rel_vect_list[Q_id]

        _delta = 2*cur_lambda*rel_vect_M*error_3d
        # Each row point accumulates its deltas over columns, and each
        # column point over rows, with opposite signs.
        cur_coords[Q_id] -= np.sum(_delta, axis=1)
        nex_coords[Q_id] += np.sum(_delta, axis=0)

        lambdas[Q_id] -= np.sum(rel_dist_M*res_error)

    return cur_coords, nex_coords, lambdas, np.sum(abs(res_error))


def update_params_fluid(cur_coords, nex_coords, lambdas, t_mat, Q_dim, num_Qs, step_size=0.002):
    """Apply one in-place "fluid" update step to coordinates and mixing weights.

    Same residual and lambda update as ``update_params_rigid``. Before the
    summed per-point steps are applied, they are multiplied by the kernel
    matrix ``1 / (1 + cdist(P, P))**2`` over the points ``P`` of the same
    set. Each point's step is therefore mixed with those of the other points
    in its set, weighted more heavily for nearby points; the diagonal weight
    is 1.

    NOTE(review): as in the rigid variant, the coordinate step uses raw
    difference vectors rather than the exact gradient of the scaled
    distance.

    Args:
        cur_coords: array indexed ``[set, row_point, Q_dim]``; modified in place.
        nex_coords: array indexed ``[set, col_point, Q_dim]``; modified in place.
        lambdas: array with ``num_Qs`` elements (e.g. shape ``(num_Qs, 1)``);
            modified in place.
        t_mat: target matrix with ``rows * cols`` entries.
        Q_dim: dimensionality of each point.
        num_Qs: number of coordinate sets.
        step_size: step scale applied to both coordinates and lambdas.

    Returns:
        ``(cur_coords, nex_coords, lambdas, sum(|residual|))``. The residual
        is measured before this update.
    """
    Q_rol = cur_coords.shape[1]
    Q_col = nex_coords.shape[1]

    rel_dist_list = np.array([5*(cdist(cur_coords[set_id], nex_coords[set_id])-0.5) for set_id in range(num_Qs)])
    rel_vect_list = np.array([rel_position(cur_coords[set_id], nex_coords[set_id], Q_dim) for set_id in range(num_Qs)])
    res_mat = np.sum(rel_dist_list*lambdas.reshape(num_Qs,1,1), axis=0)
    res_error = res_mat - t_mat

    # Loop-invariant: build the (rows, cols, 1) residual view once and let
    # broadcasting expand it across Q_dim. This replaces a per-iteration
    # np.repeat copy and gives identical results.
    error_3d = res_error.reshape(Q_rol, Q_col, 1)

    for Q_id in range(num_Qs):

        cur_lambda = lambdas[Q_id]
        rel_dist_M = step_size*rel_dist_list[Q_id]
        rel_vect_M = step_size*rel_vect_list[Q_id]

        _delta = 2*cur_lambda*rel_vect_M*error_3d

        # Proximity kernels within each set. Both are computed before either
        # set is moved.
        rel_curQs = cdist(cur_coords[Q_id], cur_coords[Q_id])+1
        rel_nexQs = cdist(nex_coords[Q_id], nex_coords[Q_id])+1

        cur_coords[Q_id] -= (1/rel_curQs**2)@np.sum(_delta, axis=1)
        nex_coords[Q_id] += (1/rel_nexQs**2)@np.sum(_delta, axis=0)

        lambdas[Q_id] -= np.sum(rel_dist_M*res_error)

    return cur_coords, nex_coords, lambdas, np.sum(abs(res_error))
        



def init_dyn_coords(_transDict, retrain_vars=None, res_coords=None):
    """Create randomly initialised coordinate sets for entries of ``_transDict``.

    Each value of ``_transDict`` is indexed as ``[split_configs, basic_configs]``:
    ``split_configs = [num_inputs, num_outputs, block_dimX, block_dimY]`` and
    ``basic_configs = [num_Qs, Q_dim, step_size, var_scale]``.

    Args:
        _transDict: mapping var_name -> ``[split_configs, basic_configs]``.
        retrain_vars: if non-empty, only these names are (re)initialised.
            ``None`` or empty means all names.
        res_coords: dict to fill and return. A new dict is created when
            ``None``, so separate calls never share state. Entries for
            initialised names are overwritten.

    Returns:
        ``res_coords``, mapping var_name to a dict with:
        ``'inputs'`` uniform [0, 1) of shape ``(num_Qs, num_inputs, Q_dim)``;
        ``'outputs'`` uniform [0, 1) of shape ``(num_Qs, num_outputs, Q_dim)``;
        ``'lambdas'`` ``randn(num_Qs, 1) / (0.5 * num_Qs)``;
        ``'configs'`` holding copies of the two config lists.
    """
    # A mutable {} default would be filled and returned, leaking entries
    # into every later call that relies on the default.
    if res_coords is None:
        res_coords = {}

    for var_name in _transDict:
        if retrain_vars and var_name not in retrain_vars:
            continue
        num_inputs, num_outputs, block_dimX, block_dimY = _transDict[var_name][0]
        num_Qs, Q_dim, step_size, var_scale = _transDict[var_name][1]
        # Draw order (inputs, outputs, lambdas) is kept so seeded runs reproduce.
        res_coords[var_name] = {
            'inputs': np.random.rand(num_Qs, num_inputs, Q_dim),
            'outputs': np.random.rand(num_Qs, num_outputs, Q_dim),
            'lambdas': np.random.randn(num_Qs, 1)/(0.5*num_Qs),
            'configs': {
                'split_configs': [num_inputs, num_outputs, block_dimX, block_dimY],
                'basic_configs': [num_Qs, Q_dim, step_size, var_scale],
            },
        }

    return res_coords

def recover_weights(cur_coords, nex_coords, lambdas):
    """Reconstruct a matrix from coordinate sets and their mixing weights.

    For each set ``q``, the pairwise Euclidean distances between the rows of
    ``cur_coords[q]`` and ``nex_coords[q]`` are transformed to
    ``5 * (d - 0.5)``. The result is the ``lambdas``-weighted sum over sets.

    Args:
        cur_coords: indexable of ``num_Qs`` point arrays, each ``(rows, Q_dim)``.
        nex_coords: indexable of ``num_Qs`` point arrays, each ``(cols, Q_dim)``.
        lambdas: array whose first dimension is ``num_Qs``.

    Returns:
        Matrix of shape ``(rows, cols)``.
    """
    num_Qs = lambdas.shape[0]
    # Only the distances are needed. The previous version also built the full
    # (rows, cols, Q_dim) difference tensor per set and then discarded it.
    rel_dist_list = np.array([5*(cdist(cur_coords[set_id], nex_coords[set_id])-0.5) for set_id in range(num_Qs)])
    return np.sum(rel_dist_list*lambdas.reshape(num_Qs, 1, 1), axis=0)

def train_single_dyn(cur_coords, nex_coords, lambdas, t_mat, split_config, Q_dim, num_Qs, step_size):
    """Run one sweep of ``update_params_fluid`` over every tile of ``t_mat``.

    ``t_mat`` is split into ``row_step x col_step`` tiles. Each tile is fitted
    by the matching row slice of ``cur_coords`` and column slice of
    ``nex_coords``, and all tiles share ``lambdas``. The coordinate slices
    are basic slices (views), so the updates land in the caller's arrays;
    the assignment back is kept for clarity.

    Args:
        split_config: ``[row_dim, col_dim, row_step, col_step]``.

    Returns:
        ``(cur_coords, nex_coords, lambdas, total)``, where ``total`` sums the
        per-tile absolute residuals (each measured before that tile's update).
    """
    row_dim, col_dim, row_step, col_step = split_config
    row_starts = range(0, row_dim, row_step)
    col_starts = range(0, col_dim, col_step)

    total_error = 0
    for r0 in row_starts:
        rows = slice(r0, r0 + row_step)
        for c0 in col_starts:
            cols = slice(c0, c0 + col_step)
            (cur_coords[:, rows, :], nex_coords[:, cols, :], lambdas, tile_error) = update_params_fluid(
                cur_coords[:, rows, :],
                nex_coords[:, cols, :],
                lambdas,
                t_mat[rows, cols],
                Q_dim,
                num_Qs,
                step_size=step_size,
            )
            total_error += tile_error

    return cur_coords, nex_coords, lambdas, total_error


def recover_mat_from_dyn(i_coords, o_coords, lambdas, _prec=0, _scale=1):
    """Reconstruct a matrix from coordinates and divide it by ``_scale``.

    When ``_prec`` is non-zero, both coordinate arrays are first snapped down
    to a multiple of ``_prec`` (``_prec * floor(x / _prec)``).
    """
    if _prec == 0:
        return recover_weights(i_coords, o_coords, lambdas) / _scale
    snapped_in = _prec * (i_coords // _prec)
    snapped_out = _prec * (o_coords // _prec)
    return recover_weights(snapped_in, snapped_out, lambdas) / _scale


def recover_weights_from_dyn(var_name, _coords, _prec=0):
    """Reconstruct the original-scale weight matrix stored under ``var_name``.

    Args:
        var_name: key into ``_coords``.
        _coords: mapping var_name -> dict with ``'inputs'``, ``'outputs'``,
            ``'lambdas'`` and ``'configs'['basic_configs']``
            (``[num_Qs, Q_dim, step_size, var_scale]``).
        _prec: coordinate snapping precision forwarded to
            ``recover_mat_from_dyn`` (0 disables snapping).

    Returns:
        The reconstruction divided by the stored ``var_scale``.
    """
    _info = _coords[var_name]
    i_coords, o_coords, lambdas = _info['inputs'], _info['outputs'], _info['lambdas']
    _num_Qs, _Q_dim, _step_size, var_scale = _info['configs']['basic_configs']

    # Forward the caller's precision. It was previously hard-coded to 0,
    # silently ignoring the _prec argument.
    return recover_mat_from_dyn(i_coords, o_coords, lambdas, _prec=_prec, _scale=var_scale)

def recover_model_from_dyn(_model, _coords, _prec=0, fit_thres=0.1, trans_vars=None):
    """Return a copy of ``_model`` with selected weights rebuilt from ``_coords``.

    For every name in ``_coords``, the weight is reconstructed with
    ``recover_weights_from_dyn``. It is scored against the model's tensor as
    ``fit = mean(|reconstruction - target|) / var(target)``. A weight is
    replaced when it is listed in ``trans_vars``, or, if ``trans_vars`` is
    empty/None, when ``fit < fit_thres``.

    Args:
        _model: torch module to copy; the original is not modified.
        _coords: mapping var_name -> coordinate dict with ``'inputs'``,
            ``'outputs'``, ``'lambdas'`` and ``'configs'``.
        _prec: coordinate snapping precision forwarded to the reconstruction.
        fit_thres: fit threshold used when ``trans_vars`` is empty/None.
        trans_vars: explicit names to replace; empty/None selects by fit.

    Returns:
        The modified deep copy. Each variable's fit, the list of replaced
        names, and the size statistics are printed.
    """
    # NOTE(review): 'fixed' starts at 9 rather than 0; the reason is not
    # visible here. Confirm it is intentional.
    size_stat = {'fixed': 9, 'trans_dyn': 0, 'trans_raw': 0}
    _trans_vars = []
    c_model = copy.deepcopy(_model)
    # state_dict() tensors share storage with c_model's parameters, so
    # copy_() into them updates c_model. Fetch the dict once instead of
    # rebuilding it for every variable.
    c_state = c_model.state_dict()

    for var_name, _info in _coords.items():
        tg_param = c_state[var_name]
        rc_weight = recover_weights_from_dyn(var_name, _coords, _prec=_prec)
        # Follow the target's device so comparison and copy also work when
        # the model is on a GPU. This is a no-op on CPU.
        rc_param = torch.from_numpy(rc_weight).to(tg_param.device)

        _loss = torch.sum(abs(rc_param-tg_param))/tg_param.numel()
        _fit = (_loss/torch.var(tg_param)).item()
        print(var_name, _fit)

        if trans_vars:
            replace = var_name in trans_vars
        else:
            replace = _fit < fit_thres

        if replace:
            tg_param.copy_(rc_param)
            _trans_vars.append(var_name)
            # Stored size counts one entry per point per set (Q_dim is not
            # multiplied in).
            size_stat['trans_dyn'] += _info['inputs'].shape[0]*(_info['inputs'].shape[1]+_info['outputs'].shape[1])
            size_stat['trans_raw'] += tg_param.numel()
        else:
            size_stat['fixed'] += tg_param.numel()

    print(_trans_vars)
    print(size_stat)
    return c_model


def train_full_dyn(_model, _coords=None, error_thres=1, max_epoch=50000, adj_dict=None):
    """Fit every coordinate set in ``_coords`` to the matching weight of ``_model``.

    For each variable, ``train_single_dyn`` is run for up to ``max_epoch``
    epochs against the weight scaled by ``var_scale``. Every 5000 epochs the
    reconstruction is compared with the unscaled weight, and
    ``fit_val = sum|W - W_rec| / (var(W) * rows * cols)`` is printed.
    Training of that variable stops early once ``fit_val < error_thres``.
    After every 5th trained variable, ``_coords`` is checkpointed with
    ``np.save``.

    NOTE: the checkpoint filename reads the module-level globals ``date_id``,
    ``hardAdj_default`` and ``global_dim``; they must exist before calling.

    Args:
        _model: torch module whose weights are the targets; it is only read.
        _coords: mapping var_name -> coordinate dict; updated in place.
        error_thres: early-stop threshold on ``fit_val``.
        max_epoch: maximum epochs per variable.
        adj_dict: if non-empty, only names contained in it are trained.

    Returns:
        ``_coords`` (the same object, updated).
    """
    import os  # only needed to create the checkpoint directory

    if _coords is None:
        _coords = {}

    ss = time.time()
    # Weights are only read, so the model's own state dict suffices; a deep
    # copy of the whole model is unnecessary.
    state = _model.state_dict()
    train_count = 0
    for var_name, _info in _coords.items():
        if adj_dict and var_name not in adj_dict:
            continue
        t_mat = state[var_name].cpu().numpy()
        split_configs = _info['configs']['split_configs']
        num_Qs, Q_dim, step_size, var_scale = _info['configs']['basic_configs']
        print('----- ', var_name, ':', split_configs, [num_Qs, Q_dim, step_size, var_scale])

        # The scaled target is constant during training. Compute it once
        # instead of allocating a new matrix every epoch.
        scaled_t_mat = t_mat*var_scale
        count_eles = split_configs[0]*split_configs[1]

        pre_error = 0
        for _ep in range(max_epoch):
            _info['inputs'], _info['outputs'], _info['lambdas'], _error = train_single_dyn(
                _info['inputs'], _info['outputs'], _info['lambdas'], scaled_t_mat,
                split_configs, Q_dim, num_Qs, step_size)

            if _ep % 5000 == 0:
                ee = time.time()
                var_weight = recover_weights_from_dyn(var_name, _coords, _prec=0)
                cur_error = np.sum(abs(t_mat-var_weight))
                fit_val = cur_error/(np.var(t_mat)*count_eles)
                print(_ep, 'fit_val:', fit_val, '-- Cur Error:', cur_error, '-- Step:', pre_error-cur_error, '-- TC:', ee-ss)
                pre_error = cur_error
                ss = time.time()
                if fit_val < error_thres:
                    break

        train_count += 1
        print('------ Current TrainID:', train_count)
        if train_count % 5 == 0:
            ckpt_path = ('./imagenet_coords/swin_coords/swinS_'+date_id+'_hj'+str(int(hardAdj_default*10))
                         +'_gd'+str(int(global_dim))+'_TP'+str(train_count)+'.npy')
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            # Save this function's own coordinates. The original saved the
            # module-level dyn_coords, which only coincided with _coords in
            # this script.
            np.save(ckpt_path, _coords)
        print()

    return _coords


# Preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
val_dataset = torchvision.datasets.ImageNet(root='./imagenet_data', split='val', transform=transform)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

# Number of batches in the validation loader.
print(len(test_loader))


# Alternative pretrained backbones; uncomment one to experiment with it.
#resnet_model = models.resnet50(pretrained=True)
#resnet152_model = models.resnet152(pretrained=True)
#densenet_model = models.densenet121(pretrained=True)
#deitS_model = timm.create_model('deit_small_patch16_224', pretrained=True)
#inception_model = models.inception_v3(pretrained=True)
#vitS_model = timm.create_model('vit_small_r26_s32_224', pretrained=True)
# Pretrained Swin-S model whose 2-D '.weight' tensors are fitted by the
# coordinate training in this script.
swinS_model = timm.create_model('swin_small_patch4_window7_224', pretrained=True)


# Per-variable multipliers for the number of coordinate sets; variables not
# listed here use hardAdj_default.
var_hard_adjust = {
    'layers.0.blocks.0.attn.qkv.weight': 1.5,
    'layers.0.blocks.0.attn.proj.weight': 2.0,
    'layers.0.blocks.0.mlp.fc1.weight': 1.5,
    'layers.0.blocks.1.attn.qkv.weight': 1.5,
    'layers.0.blocks.1.attn.proj.weight': 2.0,
    'layers.1.blocks.0.attn.qkv.weight': 1.5,
    'layers.1.blocks.0.attn.proj.weight': 2.0,
}

# Names to (re)initialise; an empty list means every entry.
retrain_vars = []

# Running totals of original vs. re-parameterised sizes.
total_params = trans_params = 0

# Experiment settings. These also appear in the output filenames.
global_dim = 16            # Q_dim: dimensionality of every coordinate point
hardAdj_default = 1.5      # default multiplier for the number of sets
date_id = 'sep26_fld_S2E3_10W'

# Build the per-variable configuration for every 2-D tensor whose name
# contains '.weight': [[rows, cols, row_step, col_step],
#                      [num_sets, Q_dim, step_size, var_scale]].
_swinS_transDict = {}
for name, param in swinS_model.state_dict().items():
    list_param = list(param.size())
    if len(list_param) != 2 or '.weight' not in name:
        continue

    hard_adj_val = var_hard_adjust.get(name, hardAdj_default)

    # NOTE(review): round(1/var) becomes 0 when var > 2, which would later
    # divide by zero in the reconstruction. Fine for typical trained weights,
    # but not guarded.
    var_scale = round(1/torch.var(param).item())
    num_subs = round(hard_adj_val*param.numel()/(global_dim*(sum(list_param))))+2
    step_val = round(5/num_subs, 3)

    print(name, param.size(), param.numel(), var_scale, num_subs, step_val)
    add_unit = [50, 48] if name == 'head.weight' else [48, 48]

    _swinS_transDict[name] = [list_param+add_unit, [num_subs, global_dim, step_val*1e-06, var_scale]]

    total_params += param.numel()
    trans_params += num_subs*sum(list_param)

print(total_params, trans_params)

import os

root_path = './imagenet_coords/swin_coords/'
# Create the output directory up front. Otherwise the final np.save (and the
# periodic checkpoints) would fail after a multi-hour run if it were missing.
os.makedirs(root_path, exist_ok=True)

# To resume from a saved checkpoint instead of a fresh random init:
#dyn_coords = np.load(root_path+'swinS_sep23_hj20_gd16_10.npy',allow_pickle=True).item()
dyn_coords = init_dyn_coords(_swinS_transDict, retrain_vars=retrain_vars, res_coords={})

print(len(dyn_coords))
trained_coords = train_full_dyn(swinS_model, _coords=dyn_coords, error_thres=1e-2, max_epoch=60000)

np.save(root_path+'swinS_full_'+date_id+'_hj'+str(int(hardAdj_default*10))+'_gd'+str(int(global_dim))+'.npy', trained_coords)




