import io
from copy import deepcopy
from typing import Any, Union, Dict
from pathlib import Path

import matplotlib.pyplot as plt
import PIL.Image
import numpy as np
import torch
from torch.nn.modules import Module
from torch.utils.tensorboard import SummaryWriter
import torch_pruning as tp

from torchvision.transforms import ToTensor

from bypass.configs import BypassConfig
from bypass.utils import search_threshold, get_norm_vector, search_threshold_ratio, search_epsilon

def l2_distance(w1, w2, center):
    """Return the Euclidean distance between two keyed tensor collections.

    Args:
        w1: Mapping from name to tensor. Its keys drive the iteration.
        w2: Mapping that holds a same-shaped tensor for every key of ``w1``.
        center: Unused. Accepted so that every function in the metric
            registry can be called with the same ``(w1, w2, center)``
            arguments.

    Returns:
        A 0-dim tensor holding ``sqrt(sum_k ||w1[k] - w2[k]||^2)``. It lives
        on the device of the first tensor in ``w1`` (CPU when ``w1`` is
        empty, in which case the distance is 0).

    Raises:
        KeyError: If a key of ``w1`` is missing from ``w2``.
    """
    # Follow the inputs' device instead of assuming CUDA, so the metric also
    # works on CPU-only hosts and for CPU-resident state dicts.
    first = next(iter(w1.values()), None)
    device = first.device if first is not None else 'cpu'
    ret = torch.tensor(0., device=device)
    for k, x1 in w1.items():
        ret += torch.sum((x1 - w2[k]) ** 2)
    return ret ** 0.5
def cosine_sim(w1, w2, center):
    """Cosine similarity of ``w1`` and ``w2`` measured relative to ``center``.

    Every tensor is shifted by the matching entry of ``center`` before the
    inner product and the two squared norms are accumulated over all keys.

    Args:
        w1: Mapping from name to tensor. Its keys drive the inner product
            and the first norm.
        w2: Mapping from name to tensor. Its keys drive the second norm.
        center: Mapping that supplies the origin tensor for every key used.

    Returns:
        ``<w1 - c, w2 - c> / sqrt(||w1 - c||^2 * ||w2 - c||^2)``.
    """
    dot = sum(
        torch.sum((w1[key] - center[key]) * (w2[key] - center[key]))
        for key in w1
    )
    sq_norm_first = sum(torch.sum((w1[key] - center[key]) ** 2) for key in w1)
    sq_norm_second = sum(torch.sum((w2[key] - center[key]) ** 2) for key in w2)
    denominator = (sq_norm_first * sq_norm_second) ** 0.5
    return dot / denominator
def estimate_latency(model, example_inputs, repetitions=50, warmup=0):
    """Time forward passes of ``model`` on the GPU using CUDA events.

    Args:
        model: Callable invoked as ``model(example_inputs)``.
        example_inputs: Input handed to ``model`` on every pass.
        repetitions: Number of timed forward passes. Must be positive.
        warmup: Number of untimed forward passes run first. Defaults to 0,
            which matches the previous behaviour of timing from the very
            first call.

    Returns:
        ``(mean, std)`` of the per-pass timings. The values come from
        ``torch.cuda.Event.elapsed_time`` and are therefore in milliseconds.

    Raises:
        ValueError: If ``repetitions`` is not positive.
    """
    if repetitions <= 0:
        # Guard the division below; an empty sample would only yield NaNs.
        raise ValueError(f'repetitions must be positive, got {repetitions}')

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    timings = np.zeros((repetitions, 1))

    with torch.no_grad():
        for _ in range(warmup):
            model(example_inputs)

        for rep in range(repetitions):
            starter.record()
            _ = model(example_inputs)
            ender.record()
            # The events are recorded asynchronously; wait for the GPU before
            # reading the elapsed time.
            torch.cuda.synchronize()
            timings[rep] = starter.elapsed_time(ender)

    mean_syn = np.sum(timings) / repetitions
    std_syn = np.std(timings)
    return mean_syn, std_syn
# Registry mapping a metric name to the function that implements it. Every
# entry takes the same (w1, w2, center) arguments. Note that 'angular'
# resolves to cosine_sim, i.e. it yields a cosine similarity, not an angle.
METRICS={'l2':l2_distance,'angular':cosine_sim}

class ParameterMetric:
    """Base class for metrics computed from a model and logged as scalars.

    Subclasses override ``__call__`` to return a ``{tag: value}`` mapping.
    ``write`` forwards every entry to the writer under ``param_metric/<tag>``.
    """

    def __init__(self, writer: "SummaryWriter", name=None):
        """Store the writer and the metric name.

        Args:
            writer: Object exposing ``add_scalar``. A falsy value disables
                logging in ``write``.
            name: Metric name. Defaults to ``str`` of the instance's class.
        """
        self.writer = writer
        self.name = str(self.__class__) if name is None else name

    def __call__(self, model: torch.nn.Module):
        """Compute the metric for ``model``. Must be overridden.

        Raises:
            NotImplementedError: Always, on the base class. Returning the
                ``NotImplemented`` sentinel (as done previously) made
                ``write`` fail later with an unrelated ``AttributeError``.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement __call__'
        )

    def write(self, model: torch.nn.Module, global_step: int):
        """Evaluate the metric and log each entry at ``global_step``.

        Does nothing, and does not evaluate the metric, when no writer is
        configured.
        """
        if not self.writer:
            return
        for k, v in self(model).items():
            self.writer.add_scalar(f'param_metric/{k}', v, global_step=global_step)
    

class TrajectoryMetric(ParameterMetric):
    """Distance or similarity between a model's weights and a reference snapshot.

    The reference is a deep copy of the model's ``state_dict`` taken at
    construction time. Before comparing, tensors of layers that carry a
    ``*mask`` entry are zero-padded back to their unpruned shape so that
    differently pruned snapshots are comparable.
    """

    def __init__(self,model:torch.nn.Module,writer:SummaryWriter,load_from=None,center=None,metric='l2'):
        """Snapshot ``model`` and select the metric function.

        Args:
            model: Module whose current ``state_dict`` becomes the reference.
            writer: Writer used by the inherited ``write``.
            load_from: Must be ``None``. Loading a stored reference was never
                implemented; any other value previously left the instance
                without a reference and failed later with ``AttributeError``.
            center: ``None`` uses zeros (mask entries are kept as they are).
                ``'base'`` uses the reference snapshot itself and appends
                ``_basecenter`` to the metric name.
            metric: Key into ``METRICS``.

        Raises:
            NotImplementedError: If ``load_from`` is given.
            ValueError: If ``center`` is neither ``None`` nor ``'base'``.
            KeyError: If ``metric`` is not registered in ``METRICS``.
        """
        self.writer = writer
        self.name = f'trajectory_{metric}'
        if load_from is not None:
            raise NotImplementedError(
                'TrajectoryMetric does not support load_from yet; pass load_from=None'
            )
        self.starting_point = deepcopy(model.state_dict())
        if center is None:
            self.center = {
                k: v if k.endswith('mask') else torch.zeros_like(v)
                for k, v in self.starting_point.items()
            }
        elif center == 'base':
            self.center = self.starting_point
            self.name = f'{self.name}_basecenter'
        else:
            raise ValueError(f"center must be None or 'base', got {center!r}")

        self.metric_function = METRICS[metric]

    @staticmethod
    def _remaining_indices(total_channels, pruned_indices):
        """Return, in ascending order, the indices below ``total_channels`` that are not pruned."""
        pruned = set(pruned_indices.cpu().tolist())
        # Build the list from range() so the order is guaranteed. Converting
        # a set difference to a list gives no ordering guarantee and could
        # scatter the surviving channels into the wrong slots.
        return [i for i in range(total_channels) if i not in pruned]

    def _temporal_pruning_restoration(self,model):
        """Return a mask-free state dict with pruned channels re-inserted as zeros.

        Args:
            model: A ``torch.nn.Module`` or an already extracted state dict.

        Returns:
            A dict with every non-mask entry. For each ``<layer>.<mask>``
            entry, the tensors of ``<layer>`` are expanded along dim 0
            (``out_mask``) or dim 1 (``in_mask``). The mask value is read as
            a 1-D tensor of pruned channel indices. 0-dim tensors are left
            untouched, and ``bias`` entries are skipped for ``in_mask``.

        Raises:
            NameError: If a tensor with two or more dims belongs to a mask
                that is neither ``out_mask`` nor ``in_mask``.
        """
        if isinstance(model, torch.nn.Module):
            state_dict: dict = model.state_dict()
        else:
            state_dict = model
        ret = {k: v for k, v in state_dict.items() if not k.endswith('mask')}
        mask_names = [x for x in state_dict.keys() if x.endswith('mask')]
        for mask_name in mask_names:
            layer_name, _, mask_type = mask_name.rpartition('.')
            pruned_indices = state_dict[mask_name]

            for k, v in state_dict.items():
                if k.endswith('mask') or k.rpartition('.')[0] != layer_name:
                    continue
                if v.dim() == 0:
                    continue
                if mask_type == 'in_mask' and k.endswith('bias'):
                    continue

                current = ret[k]
                if v.dim() == 1:
                    # NOTE(review): the size is derived from the stored
                    # tensor, not from `current`. Confirm a 1-D tensor is
                    # never expanded by two masks of the same layer.
                    total_channels = v.shape[0] + len(pruned_indices)
                    remain_indices = self._remaining_indices(total_channels, pruned_indices)
                    new_v = torch.zeros(total_channels, device=v.device, dtype=v.dtype)
                    new_v[remain_indices] = current
                elif mask_type == 'out_mask':
                    total_channels = current.shape[0] + len(pruned_indices)
                    remain_indices = self._remaining_indices(total_channels, pruned_indices)
                    new_v = torch.zeros(
                        total_channels, *current.shape[1:],
                        device=v.device, dtype=v.dtype,
                    )
                    new_v[remain_indices] = current
                elif mask_type == 'in_mask':
                    total_channels = current.shape[1] + len(pruned_indices)
                    remain_indices = self._remaining_indices(total_channels, pruned_indices)
                    new_v = torch.zeros(
                        current.shape[0], total_channels, *current.shape[2:],
                        device=v.device, dtype=v.dtype,
                    )
                    new_v[:, remain_indices] = current
                else:
                    raise NameError(f'Unidentified mask name: {mask_name}')

                ret[k] = new_v
        return ret

    def __call__(self,model1:torch.nn.Module,model2=None):
        """Evaluate the configured metric between ``model1`` and ``model2``.

        Args:
            model1: Module or state dict to evaluate.
            model2: Module or state dict to compare against. Defaults to the
                snapshot taken at construction time.

        Returns:
            ``{self.name: value}`` where ``value`` is whatever the metric
            function returns.
        """
        if model2 is None:
            model2 = self.starting_point
        w1 = self._temporal_pruning_restoration(model1)
        w2 = self._temporal_pruning_restoration(model2)
        center = self._temporal_pruning_restoration(self.center)
        # Give all three collections the same key set; an entry missing on
        # one side is treated as all zeros.
        for k, v in w2.items():
            if k not in w1:
                w1[k] = torch.zeros_like(v, device=v.device)
            if k not in center:
                center[k] = torch.zeros_like(v)
        for k, v in w1.items():
            if k not in w2:
                w2[k] = torch.zeros_like(v, device=v.device)
            if k not in center:
                center[k] = torch.zeros_like(v)
        distance = self.metric_function(w1, w2, center)
        return {self.name: distance}
    
class DvecWatcher(ParameterMetric):
    """Report one scalar summary (min |d|, max |d| or norm) of a bypass layer's ``D`` delta."""

    def __init__(self,bypass_index:int,type:str,writer:SummaryWriter) -> None:
        """Remember which bypass layer to read and which reduction to apply.

        Args:
            bypass_index: Key into ``model.bypass_layers``.
            type: One of ``'min'``, ``'max'`` or ``'norm'``.
            writer: Writer used by the inherited ``write``.
        """
        self.bypass_index = bypass_index
        self.type = type
        self.writer = writer
        assert type in ('min', 'max', 'norm')
        self.name = f'D{type}{bypass_index}'

        # Resolve the reduction once; it is one of the *_func methods below.
        reducer_name = f'{type}_func'
        self.ret_function = getattr(self, reducer_name)

    def __call__(self,model:torch.nn.Module):
        """Return ``{'Dvec/<name>': reduction(delta)}`` for the watched layer."""
        layer_group = model.bypass_layers[self.bypass_index]
        delta = layer_group['D'].delta
        value = self.ret_function(delta)
        return {f'Dvec/{self.name}': value}

    def min_func(self,dvec):
        """Smallest absolute entry of ``dvec``."""
        return torch.abs(dvec).min()

    def max_func(self,dvec):
        """Largest absolute entry of ``dvec``."""
        return torch.abs(dvec).max()

    def norm_func(self,dvec):
        """Norm of ``dvec`` as computed by ``torch.linalg.norm``."""
        return torch.linalg.norm(dvec)
    
class L2normWatcher(ParameterMetric):
    """Sum of squared entries over all parameters whose name lacks ``'delta'``.

    Note that no square root is taken: the reported value is the squared
    L2 norm of the selected parameters.
    """

    def __init__(self,writer:SummaryWriter,name='weight_norm'):
        self.name = name
        self.writer = writer

    def __call__(self,model:torch.nn.Module) -> Any:
        """Return ``{self.name: total}``; ``total`` stays the int 0 if nothing matches."""
        total = 0
        with torch.no_grad():
            for param_name, tensor in model.named_parameters():
                if 'delta' in param_name:
                    continue
                total += tensor.pow(2).sum()
        return {self.name: total}
class PruningDistributionMetric(ParameterMetric):
    """Histogram the norm vectors of one bypass layer, split into pruned and preserved entries.

    For the selected bypass layer group the norm vectors of its ``W``, ``D``
    and ``A`` members (plus ``C = D / W``) are collected. ``search_epsilon``
    supplies the pruned/preserved indices, and one histogram per vector is
    rendered to an in-memory PNG.
    """

    def __init__(self,writer:SummaryWriter,name='prune_distribution'):
        self.writer = writer
        self.name = name
        # Only used in the file name of the norm dump saved by write().
        self.prune_count = 1

    def __call__(self,model,idx=None):
        """Build the histograms for bypass layer ``idx``.

        Args:
            model: Object exposing ``bypass_layers``, ``dependancy``,
                ``pruning_epsilon_dict`` and ``ADW_loss_single``.
            idx: Key into ``model.bypass_layers``. When ``None``, the key
                with the smallest ``ADW_loss_single`` value is used.

        Returns:
            ``(plots, weight_norm_hist)``. ``plots`` maps
            ``'<name>/<X>_plot'`` to a PNG buffer, or ``None`` when vector
            ``X`` is unavailable. ``weight_norm_hist`` maps each of
            ``'W'``, ``'D'``, ``'A'``, ``'C'`` to a dict of numpy arrays
            (``all`` / ``preserved`` / ``pruned``) or ``None``.
        """
        if idx is None:
            with torch.no_grad():
                idx = min(model.bypass_layers.keys(), key=lambda x: model.ADW_loss_single(x))
        bypass_layer_group = model.bypass_layers[idx]
        A_layer = bypass_layer_group['A'][0]
        W_layer = bypass_layer_group['W'][0]
        D_activ = bypass_layer_group['D']

        D_norm = get_norm_vector(D_activ.delta)

        # Follow single 'prev' links until a layer with a weight is found.
        # Give up (W_norm = None) when the chain is not linear.
        while True:
            if hasattr(W_layer, 'weight'):
                W_norm = get_norm_vector(W_layer.weight)
                assert D_norm.shape == W_norm.shape
                break
            W_layer_cand = model.dependancy[W_layer]['prev']
            if len(W_layer_cand) != 1:
                W_norm = None
                break
            W_layer = W_layer_cand[0]

        # Same walk along 'next' links for the A side. A shape mismatch with
        # D is tolerated here and simply disables the A histogram.
        while True:
            if hasattr(A_layer, 'weight'):
                A_norm = get_norm_vector(A_layer.weight, axis=1)
                if not D_norm.shape == A_norm.shape:
                    A_norm = None
                break
            A_layer_cand = model.dependancy[A_layer]['next']
            if len(A_layer_cand) != 1:
                A_norm = None
                break
            A_layer = A_layer_cand[0]

        epsilon = model.pruning_epsilon_dict[idx]
        ref_type, prune_indices, preserve_indices, epsilon = search_epsilon(bypass_layer_group, epsilon)

        def _get_pruned_norm(item):
            """Split ``item`` into all / preserved / pruned numpy arrays."""
            if item is None:
                return None
            parts = {'all': item, 'preserved': item[preserve_indices], 'pruned': item[prune_indices]}
            return {k: v.cpu().detach().numpy() for k, v in parts.items()}

        C_norm = D_norm / W_norm if W_norm is not None and D_norm is not None else None
        weight_norm_hist = {
            'W': _get_pruned_norm(W_norm),
            'D': _get_pruned_norm(D_norm),
            'A': _get_pruned_norm(A_norm),
            'C': _get_pruned_norm(C_norm),
        }

        ret = {}
        plot_title = f'layer{idx} distribution'
        # Use a separate loop variable. Reusing `ref_type` shadowed the value
        # returned by search_epsilon, so the epsilon line was drawn on every
        # plot instead of only on the one the threshold refers to.
        for plot_type in 'ADWC':
            ret[f'{self.name}/{plot_type}_plot'] = self._gen_hist(
                plot_title,
                weight_norm_hist[plot_type],
                epsilon=(epsilon if plot_type == ref_type else None),
                name=plot_type,
            )
        return ret, weight_norm_hist

    def write(self, model: Module, prune_count: int,idx=None):
        """Log the histograms as images and dump the raw norms next to the logs.

        Args:
            model: Forwarded to ``__call__``.
            prune_count: ``global_step`` under which the images are logged.
            idx: Forwarded to ``__call__``.
        """
        ret, weight_norm_hist = self.__call__(model, idx=idx)
        for k, v in ret.items():
            if v is None:
                continue
            image = np.array(PIL.Image.open(v).convert('RGB'))
            self.writer.add_image(f'prune_metric/{k}', image, global_step=prune_count, dataformats='HWC')
        self.prune_count += 1
        # NOTE(review): the file name uses the internal counter (already
        # incremented) and the `idx` argument as passed, which is None when
        # the layer was auto-selected. Kept as is so existing file names do
        # not change.
        torch.save(weight_norm_hist, Path(self.writer.log_dir)/f'weightnorm_layer{idx}_count{self.prune_count}.pt')

    def _gen_hist(self,title:str,norm_vectors:Dict[str,np.ndarray],epsilon=None,name = 'D',bins=None, logscale=True):
        """Create a pyplot plot and save to buffer.

        Args:
            title: Prefix of the figure title (``'<title>_<name>'``).
            norm_vectors: Dict with ``all``, ``preserved`` and ``pruned``
                arrays, or ``None`` to skip plotting.
            epsilon: Optional threshold drawn as a vertical line.
            name: Label of the plotted quantity.
            bins: Number of bins. Defaults to half the number of entries,
                rounded up.
            logscale: Plot ``log10`` of the values when true. Otherwise plot
                raw values on bin edges that are evenly spaced in log space.

        Returns:
            A ``BytesIO`` positioned at the start of the PNG, or ``None``
            when ``norm_vectors`` is ``None``.
        """
        if norm_vectors is None:
            return None
        plt.close('all')
        fig = plt.figure()
        if bins is None:
            bins = (len(norm_vectors['all']) + 1) // 2

        # Bin edges always come from the log10 of the full vector, so both
        # histograms share the same bins.
        log_edges = np.histogram(np.log10(norm_vectors['all']), bins=bins)[1]
        if logscale:
            plt.hist(np.log10(norm_vectors['preserved']), bins=log_edges, label=f'preserve_{name}', alpha=0.8)
            plt.hist(np.log10(norm_vectors['pruned']), bins=log_edges, label=f'prune_{name}', alpha=0.8)
            if epsilon is not None:
                plt.axvline(x=np.log10(epsilon), ymin=0, label='log(ε)', color='tab:green')
            plt.xlabel(f'log(|{name}|)')
        else:
            edges = 10 ** log_edges
            plt.hist(norm_vectors['preserved'], bins=edges, label=f'preserve_{name}', alpha=0.8)
            plt.hist(norm_vectors['pruned'], bins=edges, label=f'prune_{name}', alpha=0.8)
            if epsilon is not None:
                plt.axvline(x=epsilon, ymin=0, label='ε', color='tab:green')
            plt.xlabel(f'|{name}|')
        plt.ylabel('counts')
        # This is the only title that ever reached the figure; the
        # per-branch titles set earlier were overwritten here.
        plt.title(f"{title}_{name}")
        plt.legend()
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        # The PNG is fully rendered into the buffer, so release the figure.
        plt.close(fig)
        buf.seek(0)
        return buf


        


class GlobalPruningWatcher(ParameterMetric):
    """Report model-wide cost figures: op count, parameter count and GPU latency."""

    def __init__(self,writer:SummaryWriter,input_vector_shape=(3,32,32),name="prune_metric"):
        """Create the dummy input used for counting and timing.

        Args:
            writer: Writer used by the inherited ``write``.
            input_vector_shape: Shape of one sample without the batch
                dimension. Any sequence of ints works; the default is a
                tuple so that no mutable object is shared between instances.
            name: Prefix of every reported tag.
        """
        # A single random sample, moved to the GPU. CUDA is required here.
        self.dummy_input = torch.rand([1, *input_vector_shape]).cuda()
        self.writer = writer
        self.name = name

    def __call__(self,model:torch.nn.Module):
        """Measure ``model`` on the dummy input.

        Returns:
            A dict with ``<name>/MACs(G)``, ``<name>/Params(M)``,
            ``<name>/Latency_mean(sec)`` and ``<name>/Latency_std(sec)``.
        """
        flops_count, params_count = tp.utils.count_ops_and_params(model, example_inputs=self.dummy_input)
        latency_mu, latency_std = estimate_latency(model, self.dummy_input)
        # estimate_latency reports CUDA-event timings, which are in
        # milliseconds. The tags below promise seconds, so convert.
        ms_to_sec = 1e-3
        return {
            f'{self.name}/MACs(G)': flops_count / 1e+9,
            f'{self.name}/Params(M)': params_count / 1e+6,
            f'{self.name}/Latency_mean(sec)': latency_mu * ms_to_sec,
            f'{self.name}/Latency_std(sec)': latency_std * ms_to_sec,
        }


