#!/usr/bin/env python
# coding: utf-8



import json
from pprint import pprint

import torch
import time
import random
import numpy as np
import torchvision

import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

import torchvision.models as models
import torchvision.transforms as transforms

import timm

from scipy.spatial.distance import cdist
import copy

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a module-level global read by test_model() below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)




def test_model(model, test_loader, end_id=-1):
    """Return the top-1 accuracy (in percent) of `model` over `test_loader`.

    The model is moved to the module-level `device` and switched to eval
    mode; it is NOT switched back to training mode afterwards.

    Args:
        model: a torch module mapping a batch of images to per-class scores
            (the predicted label is the argmax over dimension 1).
        test_loader: iterable yielding `(images, labels)` tensor pairs.
        end_id: index of the last batch to evaluate (inclusive), i.e.
            batches 0..end_id are used. The default -1 evaluates all batches.

    Returns:
        Accuracy as a float in [0, 100]. Returns 0.0 when the loader yields
        no samples (instead of raising ZeroDivisionError).
    """
    model = model.to(device)
    model.eval()

    number_corrects = 0
    number_samples = 0
    with torch.no_grad():
        for i, (test_images_set, test_labels_set) in enumerate(test_loader):
            test_images_set = test_images_set.to(device)
            test_labels_set = test_labels_set.to(device)

            labels_predicted = model(test_images_set).argmax(dim=1)
            number_corrects += (labels_predicted == test_labels_set).sum().item()
            number_samples += test_labels_set.size(0)

            # Stop after batch `end_id` has been counted.
            if end_id != -1 and i == end_id:
                break

    # Accuracy is undefined for an empty loader; report 0 rather than crash.
    if number_samples == 0:
        return 0.0
    return (number_corrects / number_samples) * 100
    
    

def forward_layers(Qs, lambdas, num_Qs):
    """Return the lambda-weighted sum of pairwise distance matrices.

    For each of the first `num_Qs` entries of `Qs`, the pairwise distance
    matrix between `Qs[k][0]` and `Qs[k][1]` is computed with `cdist`
    (default metric). The matrices are then scaled by `lambdas[k]` and
    summed over k.

    Args:
        Qs: indexable collection; `Qs[k][0]` and `Qs[k][1]` are the two
            point sets of the k-th pair.
        lambdas: array with exactly `num_Qs` elements (one weight per pair).
        num_Qs: number of pairs to combine.

    Returns:
        2-D array with one row per point in `Qs[k][0]` and one column per
        point in `Qs[k][1]`.
    """
    distance_mats = []
    for pair_idx in range(num_Qs):
        first_set = Qs[pair_idx][0]
        second_set = Qs[pair_idx][1]
        distance_mats.append(cdist(first_set, second_set))

    stacked = np.array(distance_mats)
    # One weight per distance matrix, broadcast over its rows and columns.
    weights = lambdas.reshape(num_Qs, 1, 1)
    return np.sum(stacked * weights, axis=0)

def rel_position(cur_q, nex_q, Q_dim):
    """Return every row of `cur_q` minus every row of `nex_q`.

    Args:
        cur_q: 2-D array whose rows have `Q_dim` elements.
        nex_q: array that broadcasts against a `(1, Q_dim)` row.
        Q_dim: number of elements in each row of `cur_q`.

    Returns:
        Array whose i-th entry is `cur_q[i].reshape(1, Q_dim) - nex_q`.
    """
    row_count = cur_q.shape[0]
    differences = [
        cur_q[row].reshape(1, Q_dim) - nex_q
        for row in range(row_count)
    ]
    return np.array(differences)


def recover_weights(cur_coords, nex_coords, lambdas):
    """Rebuild a weight matrix from per-set coordinates and set weights.

    For each set s in `range(len(lambdas))`, the pairwise distance matrix
    D_s between `cur_coords[s]` and `nex_coords[s]` is computed with `cdist`
    (default metric) and mapped to `5 * (D_s - 0.5)`. The result is
    `sum_s lambdas[s] * 5 * (D_s - 0.5)`.

    NOTE(review): the constants 5 and 0.5 are taken as-is from the original
    code; presumably they match the encoding used when the coordinates were
    fitted. Confirm against the fitting code.

    Args:
        cur_coords: indexable by set id; `cur_coords[s]` is a 2-D point set.
        nex_coords: indexable by set id; `nex_coords[s]` is a 2-D point set
            with the same point dimension as `cur_coords[s]`.
        lambdas: 1-D array of per-set weights; its length is the number of
            sets combined.

    Returns:
        2-D array with one row per point in `cur_coords[s]` and one column
        per point in `nex_coords[s]`.
    """
    num_Qs = lambdas.shape[0]

    # The original also materialised every pairwise difference vector (a
    # (num_Qs, n, m, d) array) but never used it; only the distances matter.
    rel_dist_list = np.array([
        5 * (cdist(cur_coords[set_id], nex_coords[set_id]) - 0.5)
        for set_id in range(num_Qs)
    ])
    return np.sum(rel_dist_list * lambdas.reshape(num_Qs, 1, 1), axis=0)


def recover_mat_from_dyn(i_coords, o_coords, lambdas, _prec=0, _scale=1):
    """Recover a matrix from coordinates, optionally quantising them first.

    Args:
        i_coords: input-side coordinates passed to `recover_weights`.
        o_coords: output-side coordinates passed to `recover_weights`.
        lambdas: per-set weights passed to `recover_weights`.
        _prec: quantisation step. When non-zero, both coordinate arrays are
            floored to a multiple of `_prec` before recovery; 0 disables it.
        _scale: divisor applied to the recovered matrix.

    Returns:
        `recover_weights(...) / _scale`.
    """
    if _prec != 0:
        # Snap each coordinate down to the nearest multiple of `_prec`.
        i_coords = _prec * (i_coords // _prec)
        o_coords = _prec * (o_coords // _prec)

    recovered = recover_weights(i_coords, o_coords, lambdas)
    return recovered / _scale


def recover_weights_from_dyn(var_name, _coords, _prec=0):
    """Recover the weight matrix stored under `var_name` in `_coords`.

    Args:
        var_name: key into `_coords`.
        _coords: mapping whose entries provide 'inputs', 'outputs',
            'lambdas', and 'configs' -> 'basic_configs' (a 4-item sequence
            whose last item is the scale divisor).
        _prec: quantisation step forwarded to `recover_mat_from_dyn`.

    Returns:
        The matrix produced by `recover_mat_from_dyn` for this entry.
    """
    entry = _coords[var_name]
    inputs = entry['inputs']
    outputs = entry['outputs']
    set_weights = entry['lambdas']
    # Only the fourth config value (the scale) is needed here; the unpacking
    # still requires exactly four items.
    _num_sets, _dim, _step, scale = entry['configs']['basic_configs']

    return recover_mat_from_dyn(inputs, outputs, set_weights, _prec=_prec, _scale=scale)

def recover_model_from_dyn(_model, _coords, _prec=0, fit_thres=0.1, trans_vars=None):
    """Return a copy of `_model` with selected weights replaced by recovered ones.

    For every key in `_coords` the weight matrix is recovered via
    `recover_weights_from_dyn` and compared with the model's current
    parameter: `fit = mean(|recovered - param|) / var(param)`.

    A parameter is replaced when
      * `trans_vars` is non-empty and contains its name, or
      * `trans_vars` is empty/None and `fit < fit_thres`.

    Args:
        _model: source torch module; it is deep-copied, not modified.
        _coords: mapping from state-dict key to the entry consumed by
            `recover_weights_from_dyn` (must also expose 'inputs' and
            'outputs' arrays with at least two dimensions).
        _prec: quantisation step forwarded to the recovery routine.
        fit_thres: threshold on `fit` used when no explicit list is given.
        trans_vars: optional collection of names to replace unconditionally.
            None or an empty collection selects threshold mode.

    Returns:
        The modified deep copy of `_model`.

    Side effects:
        Prints the `size_stat` dict: 'trans_raw' / 'fixed' accumulate the
        element counts of replaced / untouched parameters, and 'trans_dyn'
        accumulates `inputs.shape[0] * (inputs.shape[1] + outputs.shape[1])`
        for replaced ones.
    """
    size_stat = {'fixed': 0, 'trans_dyn': 0, 'trans_raw': 0}
    c_model = copy.deepcopy(_model)

    # state_dict() builds a fresh mapping on every call; fetch it once. Its
    # tensors share storage with the copy's parameters, so copy_() below
    # updates c_model in place.
    params = c_model.state_dict()
    explicit_selection = trans_vars is not None and len(trans_vars) > 0

    for var_name in _coords:
        _info = _coords[var_name]
        tg_param = params[var_name]
        rc_weight = recover_weights_from_dyn(var_name, _coords, _prec=_prec)
        # Keep the recovered tensor on the parameter's device so the
        # comparison also works for models that live on the GPU.
        rc_param = torch.from_numpy(rc_weight).to(tg_param.device)

        _loss = torch.sum(abs(rc_param - tg_param)) / tg_param.numel()
        _fit = (_loss / torch.var(tg_param)).item()

        if explicit_selection:
            replace = var_name in trans_vars
        else:
            replace = _fit < fit_thres

        if replace:
            tg_param.copy_(rc_param)
            size_stat['trans_dyn'] += _info['inputs'].shape[0] * (
                _info['inputs'].shape[1] + _info['outputs'].shape[1]
            )
            size_stat['trans_raw'] += tg_param.numel()
        else:
            size_stat['fixed'] += tg_param.numel()

    print(size_stat)
    return c_model



# Preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with the given mean / std triples.
transform = transforms.Compose([transforms.Resize(size=(224, 224)),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# ImageNet validation split read from ./imagenet_data, in fixed order, 32 images per batch.
val_dataset = torchvision.datasets.ImageNet(root= './imagenet_data', split='val', transform = transform)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 32, shuffle = False)

# Number of batches in the evaluation loader.
print(len(test_loader))


# Pretrained Swin-S model from timm; its weights are the comparison targets
# in recover_model_from_dyn().
swinS_model = timm.create_model('swin_small_patch4_window7_224', pretrained=True)

root_path = './imagenet_coords/'
coordFile_config = 'sep26_S2E3_8W_hj10_gd16_TP10'
# The .npy file holds a pickled Python object; .item() unwraps it from the
# 0-d array that np.load returns.
# NOTE(review): allow_pickle=True can execute arbitrary code while loading;
# only load coordinate files from a trusted source.
loaded_dyn_coords = np.load(root_path+"swin_coords/swinS_"+coordFile_config+".npy",allow_pickle=True).item()

print('Config:', coordFile_config)

# The precision printed here is written separately from the `_prec` argument
# below; keep the two values in sync when changing one.
print('--- Test Prec:', 5e-3)
# The timed span covers both the weight recovery and the full evaluation.
ss = time.time()
loaded_model = recover_model_from_dyn(swinS_model, loaded_dyn_coords, _prec=5e-3, fit_thres=0.1, trans_vars=[])
test_res = test_model(loaded_model, test_loader, end_id=-1)
ee = time.time()
# 'Acc' is the percentage returned by test_model; 'TC' is the elapsed
# wall-clock time in seconds.
print('--- Acc:', test_res, '--- TC:', ee-ss)
print()

