from typing import Tuple, Optional
from tqdm import tqdm
import matplotlib.pyplot as plt 

import torch
from torch import Tensor
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.nn.modules.utils import _pair
import torch.nn.init as init

from .info_block import ModelInfo, ModuleInfo
from ..models.layers import ConvLayer

def _pca(data: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Args:
        data (Tensor): size = (m,n)
    """
    data_centers = data - data.mean(dim=0, keepdim=True)
    cov_mat = data_centers.t() @ data_centers
    L, V = torch.linalg.eig(cov_mat)
    L, V = torch.real(L), torch.real(V)
    # sort
    L_sorted, idx_sorted = torch.sort(L, descending=True)
    V_sorted = V[:, idx_sorted]
    return L_sorted, V_sorted

class PCADesigner:
    """Build a layer stack one module at a time, sizing layers with PCA.

    Each entry yielded by ``model_info`` is either instantiated directly or,
    when it reports itself as undetermined, sized and initialised from a PCA
    of the activations produced by the layers built so far.

    Args:
        model_info: Iterable of module descriptions to build, in order.
        data_loader: Yields ``(input, target)`` pairs; targets are ignored.
        recon_rate: Cumulative eigenvalue ratio that must be exceeded when
            choosing the number of principal components to keep.
        sample_size: If given, at most this many rows of the collected data
            are (randomly) kept before running PCA.
        device: Device used for the forward passes that collect data.
        if_plot: If True, plot the cumulative eigenvalue ratio per layer.
    """

    def __init__(
        self,
        model_info: ModelInfo,
        data_loader: DataLoader,
        recon_rate: float = 0.95,
        sample_size: Optional[int] = None,
        device: torch.device = torch.device("cpu"),
        if_plot: bool = False
    ):
        self.model_info = model_info
        self.data_loader = data_loader
        self.recon_rate = recon_rate
        self.sample_size = sample_size
        self.device = device
        self.if_plot = if_plot

        # Start from an identity so the first module is fed the raw inputs.
        self.design_list = [nn.Identity()]

    def run(self):
        """Instantiate every module in ``model_info`` and append it to ``design_list``."""
        for module in self.model_info:
            if module.is_undetermined():
                instance = self._search_via_pca(module)
            else:
                instance = module.get_instance()

            self.design_list.append(instance)

    def _search_via_pca(self, module: ModuleInfo) -> nn.Module:
        """Resolve the unknown parameters of ``module`` via PCA and build it.

        Args:
            module: Description of an undetermined module.

        Returns:
            The instantiated layer, with its weight set to the leading
            principal directions and its bias (if any) zeroed.

        Raises:
            TypeError: If the module class or the built instance is not a
                ``ConvLayer``.
            AttributeError: If ``in_channels`` is unknown but no earlier layer
                has had its ``out_channels`` resolved.
            RuntimeError: If the layer needs more output channels than there
                are principal directions.
        """
        # get data
        data = self._get_data(module)
        # run pca
        L, V = _pca(data)
        del data
        torch.cuda.empty_cache()

        # Smallest number of components whose cumulative ratio exceeds recon_rate.
        ratio_cum = torch.cumsum(L, dim=0) / torch.sum(L)
        above = torch.nonzero(ratio_cum > self.recon_rate)
        if above.numel() > 0:
            D_lower_bound = int(above[0, 0]) + 1
        else:
            # Threshold never exceeded (e.g. recon_rate >= 1 or NaN ratios):
            # keep every component instead of failing on an empty index.
            D_lower_bound = int(ratio_cum.numel())
        if self.if_plot:
            plt.figure()
            plt.plot(ratio_cum.cpu().numpy())
            plt.grid()

        # organize parameters
        if module.meta_class != ConvLayer:
            raise TypeError(f"{module.meta_class.__name__} not supported!")
        # Width of the previously resolved layer, captured before this layer
        # overwrites it so the result does not depend on dict ordering.
        prev_out_channels = getattr(self, "out_channels", None)
        # Iterate over a snapshot of the keys: the setter may mutate the dict.
        for k in list(module.unknown_params):
            if k == "out_channels":
                module.set_unknown_params(k, D_lower_bound)
                self.out_channels = D_lower_bound
            elif k == "in_channels":
                if prev_out_channels is None:
                    raise AttributeError(
                        "in_channels is unknown but no previous out_channels "
                        "has been resolved"
                    )
                module.set_unknown_params(k, prev_out_channels)
        print(f"name={module.name}, class={module.meta_class}, params={module.fixed_params}+{module.unknown_params}")

        # instantiation
        instance = module.get_instance()

        # initialization
        if not isinstance(instance, ConvLayer):
            raise TypeError(f"Error type: {type(instance)}")

        out_c, in_c, k_h, k_w = instance.weight.shape
        if out_c != D_lower_bound:
            # out_channels was fixed by the description rather than by PCA.
            print("Final layer!")
        if out_c > V.size(1):
            raise RuntimeError(
                f"layer needs {out_c} output channels but only {V.size(1)} "
                "principal directions are available"
            )
        # Slice from the full basis so a fixed out_channels larger than
        # D_lower_bound can still be filled.
        basis = V[:, :out_c]
        instance.weight.data = basis.t().reshape(out_c, in_c, k_h, k_w)
        bias = getattr(instance, "bias", None)
        if bias is not None:
            init.zeros_(bias.data)

        return instance

    def _get_data(self, module: ModuleInfo) -> Tensor:
        """Run the layers built so far over the loader and collect PCA samples.

        Args:
            module: Description of the module about to be designed; it
                determines how the activations are turned into sample rows.

        Returns:
            A 2-D tensor with one sample per row, randomly subsampled to at
            most ``sample_size`` rows when ``sample_size`` is set.
        """
        _cur_model = nn.Sequential(*self.design_list).to(self.device)
        pbar = tqdm(self.data_loader)
        with torch.no_grad():
            data_list = []
            for inputs, _ in pbar:
                inputs = inputs.to(self.device)
                output = _cur_model(inputs)
                data_list.append(self._post_process(output, module))
            data = torch.cat(data_list, dim=0)
        # if sample: keep a random subset of rows (not a single indexed row)
        if self.sample_size is not None and self.sample_size < data.size(0):
            keep = torch.randperm(data.size(0), device=data.device)[:self.sample_size]
            data = data[keep]

        return data

    def _post_process(self, output: Tensor, module: ModuleInfo) -> Tensor:
        """Unfold activations into the patches the next conv layer would see.

        Args:
            output: Activations fed to ``nn.Unfold``.
            module: Description whose ``fixed_params`` supply ``kernel_size``
                and, optionally, ``dilation``, ``padding`` and ``stride``.

        Returns:
            Tensor of size (bs * n_patches, in * kh * kw), one patch per row.

        Raises:
            TypeError: If ``module.meta_class`` is not ``ConvLayer``.
        """
        if module.meta_class != ConvLayer:
            raise TypeError(f"Error type: {module.meta_class}")

        kernel_size = _pair(module.fixed_params.get("kernel_size"))
        dilation = _pair(module.fixed_params.get("dilation", 1))
        padding = _pair(module.fixed_params.get("padding", 0))
        stride = _pair(module.fixed_params.get("stride", 1))
        unfold = nn.Unfold(kernel_size=kernel_size, dilation=dilation,
                           padding=padding, stride=stride)
        # bs x (oh x ow) x (in x kh x kw)
        output_unfold = unfold(output).transpose(1, 2)
        return output_unfold.reshape(-1, output_unfold.size(-1))

