import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.patches import Ellipse, Circle
from tqdm import tqdm
from .LBFGS import *
import copy
from sklearn.linear_model import LinearRegression
import sympy
import random


# ############################ LAN ###############################

# Module-wide compute device: the first CUDA device when one is visible, else the CPU.
# Helpers below move tensors here explicitly.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def B_batch(x, grid, k=0, extend=True):
    """Evaluate order-``k`` B-spline basis functions for a batch of splines.

    Args:
        x: sample points, shape (size, batch); row i is evaluated against grid row i.
        grid: knot vectors, shape (size, grid).
        k: spline order (0 = piecewise-constant indicators).
        extend: if truthy, pad ``k`` extra knots on each side of every grid row,
            spaced by that row's average knot spacing, before evaluating.

    Returns:
        Tensor of shape (size, coef, batch) with the same dtype as ``x``.  With
        ``extend`` enabled, coef = grid + k - 1.

    Both inputs are moved to the module-level ``device``.
    NOTE(review): coincident knots give a zero denominator in the recursion
    (inf/nan in the result); callers are expected to supply distinct knots.
    """

    def extend_grid(grid, k_extend=0):
        # Pad k_extend knots to the left and right; grid shape: (size, grid).
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        for _ in range(k_extend):
            grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
            grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
        return grid.to(device)

    if extend:
        grid = extend_grid(grid, k_extend=k)

    grid = grid.unsqueeze(dim=2).to(device)  # (size, grid, 1)
    x = x.unsqueeze(dim=1).to(device)        # (size, 1, batch)

    if k == 0:
        # Indicator of the half-open knot span [t_i, t_{i+1}).  The comparison
        # product is boolean; cast it so callers (einsum / lstsq) get a float tensor.
        value = ((x >= grid[:, :-1]) * (x < grid[:, 1:])).to(x.dtype)
    else:
        # Cox-de Boor recursion:
        # B_{i,k} = (x - t_i)/(t_{i+k} - t_i) B_{i,k-1}
        #         + (t_{i+k+1} - x)/(t_{i+k+1} - t_{i+1}) B_{i+1,k-1}
        B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False)
        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
            grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
    return value  # shape: (size, coef, batch)


def coef2curve(x_eval, grid, coef, k):
    """Evaluate a batch of B-spline curves from their coefficients.

    Args:
        x_eval: evaluation points, shape (size, batch).
        grid: knot vectors, shape (size, grid).
        coef: spline coefficients, shape (size, coef).
        k: spline order.

    Returns:
        Curve values, shape (size, batch).
    """
    basis = B_batch(x_eval, grid, k)  # (size, coef, batch)
    # Contract the coefficient axis: out[i, b] = sum_j coef[i, j] * basis[i, j, b].
    curve = torch.einsum('ij,ijk->ik', coef, basis)
    return curve

def curve2coef(x_eval, y_eval, grid, k):
    """Least-squares fit of B-spline coefficients to sampled curve values.

    Args:
        x_eval: sample locations, shape (size, batch).
        y_eval: target values at those locations, shape (size, batch).
        grid: knot vectors, shape (size, grid).
        k: spline order.

    Returns:
        Coefficients of shape (size, coef), moved to the module-level ``device``.
    """
    # Design matrix per spline: rows are samples, columns are basis functions.
    design = B_batch(x_eval, grid, k).permute(0, 2, 1)  # (size, batch, coef)
    targets = y_eval.unsqueeze(dim=2)                   # (size, batch, 1)
    # The batched solve runs on the CPU; only the solution is moved back.
    fit = torch.linalg.lstsq(design.to('cpu'), targets.to('cpu'))
    return fit.solution[:, :, 0].to(device)

class Spline_batch_LAN(nn.Module):
    """A bank of ``dim`` independent learnable 1-D activations.

    Channel i of the input passes only through activation i, which computes
    ``scale_base[i] * base_fun(x) + scale_sp[i] * spline_i(x)``.  ``spline_i`` is an
    order-``k`` B-spline whose knot grid has ``num + 1`` points, initially uniform on [-1, 1].
    """

    def __init__(self, dim=2, num=5, k=3, noise_scale=0., scale_base=1.0, scale_sp=1.0,
                 base_fun=torch.nn.SiLU(), scale_sp_trainable=True, device='cpu'):
        super(Spline_batch_LAN, self).__init__()
        self.size = size = dim
        self.num = num
        self.k = k
        self.base_fun = base_fun

        # Knot grid, shape (size, num + 1).  Stored as a frozen Parameter so it is
        # saved in state_dict and moves with the module, but is never trained.
        grid = torch.einsum('i,j->ij', torch.ones(size,), torch.linspace(-1, 1, steps=num + 1))
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)

        # Initialise each spline to interpolate small uniform noise at the knots.
        noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale
        # NOTE: `device` here is the constructor argument, not the module-level global.
        noises = noises.to(device)
        # Coefficients, shape (size, num + k).
        self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k))
        self.scale_base = torch.nn.Parameter(torch.ones(size,) * scale_base)
        self.scale_sp = torch.nn.Parameter(torch.ones(size,) * scale_sp).requires_grad_(scale_sp_trainable)

    def forward(self, x):
        """Apply the activation bank channel-wise.

        Args:
            x: input of shape (batch, size).

        Returns:
            ``(y, preacts, postacts, postspline)``, each of shape (batch, size):
            ``y`` and ``postacts`` are the same output tensor, ``preacts`` is a copy of
            the input, and ``postspline`` is the spline term before ``scale_sp``.
        """
        preacts = x.clone()
        x = x.permute(1, 0)                      # (size, batch)
        base = self.base_fun(x).permute(1, 0)    # (batch, size)
        spline = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k).permute(1, 0)
        postspline = spline.clone()
        y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * spline
        return y, preacts, y, postspline

    def update_grid_from_samples(self, x):
        """Move the knots to evenly spaced empirical quantiles of ``x`` and refit ``coef``.

        The current curve is first sampled at ``x``.  ``coef`` is then refit on the new
        grid, so the curve is approximately preserved.

        Args:
            x: samples of shape (batch, size).
        """
        x_pos = torch.sort(x.permute(1, 0), dim=1)[0]  # (size, batch), sorted per channel
        # Piecewise-linear (k=1) map from linspace(-1, 1, batch) onto the sorted samples,
        # i.e. a per-channel empirical quantile function.
        sp2 = Spline_batch_LAN(dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(device)
        sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
        percentile = torch.linspace(-1, 1, self.num + 1).to(device)
        self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

    def initialize_grid_from_parent(self, parent, preacts):
        """Resample ``parent``'s grid to ``self.num + 1`` knots and refit ``parent``'s curve.

        Args:
            parent: a Spline_batch_LAN with the same ``size``.
            preacts: sample inputs of shape (batch, size), used as fitting points.
        """
        x_eval = preacts.permute(1, 0)  # (size, batch)
        x_pos = parent.grid
        # Piecewise-linear interpolation of the parent's knot positions.
        sp2 = Spline_batch_LAN(dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(device)
        sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1)
        y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k)
        percentile = torch.linspace(-1, 1, self.num + 1).to(device)
        self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
        self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k)

    def get_subset(self, id_):
        """Return a new bank holding only the channels indexed by ``id_``.

        The subset keeps this bank's ``base_fun`` and ``scale_sp`` trainability.
        Its grid, coef and scales are copied (advanced indexing) from this bank.
        """
        spb = Spline_batch_LAN(self.size, self.num, self.k, base_fun=self.base_fun,
                               scale_sp_trainable=self.scale_sp.requires_grad)
        spb.grid.data = self.grid[id_, :]
        spb.coef.data = self.coef[id_, :]
        spb.scale_base.data = self.scale_base[id_]
        spb.scale_sp.data = self.scale_sp[id_]
        spb.size = len(id_)
        return spb


class LAN(nn.Module):
    """Linear layers interleaved with banks of learnable per-neuron activations.

    For hidden layer l (l < depth - 1) the layer computes
    ``act_fun[l](w0 * linears[l](x)) + biases[l].weight``.  The final layer computes
    ``linears[-1](x) + biases[-1].weight`` with no activation.  ``width`` lists the
    number of neurons per layer, input first.
    """

    def __init__(self, width=None, grid=None, k=None, noise_scale=0.0, trainable_scale=True,
                 y_scale=1.0, base_fun=torch.nn.SiLU(), scale_base=1.0, scale_sp=1.0, w0=1.0,
                 weight_init_scale=1., linear_bias=False, scale_sp_trainable=True,
                 mode="function", device='cpu'):
        # NOTE: trainable_scale, y_scale and device are accepted but not used here.
        super(LAN, self).__init__()

        self.biases = []
        self.linears = []
        self.act_fun = []
        self.depth = len(width) - 1
        self.width = width
        self.w0 = w0

        for l in range(self.depth):
            if l < self.depth - 1:
                # One activation bank per hidden layer; `grid` is the number of grid intervals.
                self.act_fun.append(Spline_batch_LAN(
                    dim=width[l + 1], num=grid, k=k, noise_scale=noise_scale,
                    scale_base=scale_base, scale_sp=scale_sp, base_fun=base_fun,
                    scale_sp_trainable=scale_sp_trainable))

            linear = nn.Linear(width[l], width[l + 1], bias=linear_bias)
            if l == 0:
                # "function" mode keeps the default init; "image" shrinks by sqrt(fan_in).
                if mode == "image":
                    linear.weight.data = linear.weight.data / np.sqrt(width[0])
            else:
                # Compensate for the w0 multiplier applied in forward().
                linear.weight.data = linear.weight.data / w0 * weight_init_scale
            self.linears.append(linear)

            # Per-layer additive bias stored as a (1, width[l+1]) weight, initialised to zero.
            bias = nn.Linear(width[l + 1], 1, bias=False)
            bias.weight.data *= 0.
            self.biases.append(bias)

        self.linears = nn.ModuleList(self.linears)
        self.biases = nn.ModuleList(self.biases)
        self.act_fun = nn.ModuleList(self.act_fun)

        self.grid = grid
        self.k = k
        self.base_fun = base_fun

    def initialize_from_another_model(self, another_model, x):
        """Copy ``another_model`` (same widths, possibly different grid) into this model.

        Grids are resampled from the parent.  Spline coefficients are refit to the
        parent's spline outputs on ``x``.  Scales, weights and biases are copied (cloned,
        so the two models do not share parameter storage).

        Returns:
            self
        """
        another_model(x)  # populate another_model's cached activations
        self.initialize_grid_from_another_model(another_model, x)

        for l in range(self.depth - 1):
            spb = self.act_fun[l]
            spb_parent = another_model.act_fun[l]
            preacts = another_model.spline_preacts[l]
            postsplines = another_model.spline_postsplines[l]
            spb.coef.data = curve2coef(preacts.permute(1, 0), postsplines.permute(1, 0), spb.grid, k=spb.k)
            spb.scale_base.data = spb_parent.scale_base.data.clone()
            spb.scale_sp.data = spb_parent.scale_sp.data.clone()

        for l in range(self.depth):
            self.linears[l].weight.data = another_model.linears[l].weight.data.clone()
            if self.linears[l].bias is not None and another_model.linears[l].bias is not None:
                self.linears[l].bias.data = another_model.linears[l].bias.data.clone()
            self.biases[l].weight.data = another_model.biases[l].weight.data.clone()
        return self

    def update_grid_from_samples(self, x):
        """Re-place every hidden layer's knots using the activations produced by ``x``."""
        for l in range(self.depth - 1):
            # Re-run forward so layer l sees inputs computed with the already-updated earlier layers.
            self.forward(x)
            self.act_fun[l].update_grid_from_samples(self.acts[l])

    def initialize_grid_from_another_model(self, model, x):
        """Resample each hidden layer's grid from ``model``'s corresponding layer."""
        model(x)
        for l in range(self.depth - 1):
            self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])

    def forward(self, x):
        """Run the network on ``x`` of shape (batch, width[0]) and cache diagnostics.

        Returns the (batch, width[-1]) output and refreshes:
          acts[l]               - w0 * linear output fed to activation bank l (hidden layers)
          spline_preacts[l]     - that input as returned by the bank
          spline_postacts[l]    - the bank output
          spline_postsplines[l] - the spline term of the bank output before scale_sp
          acts_scale            - mean |input|, then mean |postacts| for each hidden layer
          weights_scale         - mean |weight| of every linear layer
        """
        self.acts = []
        self.acts_verbose = []  # kept for compatibility; not filled here
        self.spline_preacts = []
        self.spline_postacts = []
        self.acts_scale = []
        self.weights_scale = []
        self.spline_postsplines = []
        self.acts_scale.append(torch.mean(torch.abs(x), dim=0))

        for l in range(self.depth):
            linear = self.linears[l]
            if l < self.depth - 1:
                x_postlinear = self.w0 * linear(x)
                x, preacts, postacts, postspline = self.act_fun[l](x_postlinear)
                x = x + self.biases[l].weight
                self.acts.append(x_postlinear)
                self.acts_scale.append(torch.mean(torch.abs(postacts), dim=0))
                self.spline_preacts.append(preacts)
                self.spline_postacts.append(postacts)
                self.spline_postsplines.append(postspline)
            else:
                x = linear(x) + self.biases[l].weight
            self.weights_scale.append(torch.mean(torch.abs(linear.weight)))

        return x

    def plot(self, folder="./figures", beta_weight=1, beta_act=10, mask=False):
        """Draw the network: a thumbnail per hidden activation plus weighted edges.

        Needs a prior forward pass (reads acts, spline_preacts, spline_postacts,
        acts_scale).  With ``mask`` truthy it also needs ``self.mask`` as set by prune().
        Thumbnails are written to ``folder`` as ``sp_{layer}_{neuron}.png``; the folder
        is created if missing.
        """
        import os
        os.makedirs(folder, exist_ok=True)

        depth = len(self.width) - 1
        w_large = 2.0
        for l in range(depth - 1):
            for i in range(self.width[l + 1]):
                rank = torch.argsort(self.acts[l][:, i])
                plt.figure(figsize=(w_large, w_large))
                plt.gca().patch.set_edgecolor('black')
                plt.gca().patch.set_linewidth(1.5)
                plt.xticks([])
                plt.yticks([])
                plt.plot(self.spline_preacts[l][:, i][rank].cpu().detach().numpy(),
                         self.spline_postacts[l][:, i][rank].cpu().detach().numpy(),
                         color="black", lw=5)
                plt.savefig(f'{folder}/sp_{l}_{i}.png', bbox_inches="tight", dpi=400)
                plt.close()

        def score2alpha(score, mode="act"):
            # Squash a non-negative score into an opacity in [0, 1).
            if mode == "act":
                return np.tanh(beta_act * score)
            elif mode == "weight":
                return np.tanh(beta_weight * score)

        act_alpha = [score2alpha(self.acts_scale[i + 1].cpu().detach().numpy(), mode="act")
                     for i in range(len(self.acts_scale) - 1)]
        weight_alpha = [score2alpha(torch.abs(lin.weight.data).cpu().detach().numpy(), mode="weight")
                        for lin in self.linears]

        width = np.array(self.width)
        A = 1
        y0 = 0.4  # vertical distance between layers in data coordinates
        neuron_depth = depth + 1
        min_spacing = A / np.maximum(np.max(width), 5)
        # Widest hidden layer; a network without hidden layers has none.
        max_num_neuron_except_io = np.max(width[1:-1]) if len(width) > 2 else 0
        y1 = 0.4 / np.maximum(max_num_neuron_except_io, 3)  # half-size of an activation thumbnail

        fig, ax = plt.subplots(figsize=(5, 5 * (neuron_depth - 1) * y0))

        # Input and output neurons are dots; hidden neurons are drawn as thumbnails below.
        for l in range(depth + 1):
            if l == 0 or l == depth:
                n = width[l]
                for i in range(n):
                    plt.scatter(1 / (2 * n) + i / n, l * y0, s=min_spacing ** 2 * 10000, color='black')
        plt.xlim(0, 1)
        plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0)

        # Edges: red for positive weights, blue for negative, opacity from |weight|.
        for l in range(depth):
            n = width[l]
            n_next = width[l + 1]
            # Edges touching a hidden layer stop at the thumbnail border, not its centre.
            low_shift = 0. if l == 0 else y1
            up_shift = 0 if l == depth - 1 else y1
            weights = self.linears[l].weight.data
            for i in range(n):
                for j in range(n_next):
                    alpha = weight_alpha[l][j][i]
                    if mask:
                        alpha = alpha * self.mask[l][i].item() * self.mask[l + 1][j].item()
                    plt.plot([1 / (2 * n) + i / n, 1 / (2 * n_next) + j / n_next],
                             [l * y0 + low_shift, (l + 1) * y0 - up_shift],
                             color="red" if weights[j, i] > 0 else "blue", alpha=alpha)

        # Data coordinates -> figure pixels -> normalised figure coordinates (for add_axes).
        data_to_pixels = ax.transData.transform
        pixels_to_fig = fig.transFigure.inverted().transform

        def data_to_fig(point):
            return pixels_to_fig(data_to_pixels(point))

        plt.axis('off')

        for l in range(neuron_depth - 2):
            n = width[l + 1]
            for i in range(n):
                im = plt.imread(f'{folder}/sp_{l}_{i}.png')
                left = data_to_fig([1 / (2 * n) + i / n - y1, 0])[0]
                right = data_to_fig([1 / (2 * n) + i / n + y1, 0])[0]
                bottom = data_to_fig([0, (l + 1) * y0 - y1])[1]
                up = data_to_fig([0, (l + 1) * y0 + y1])[1]
                newax = fig.add_axes([left, bottom, right - left, up - bottom])
                alpha = act_alpha[l][i]
                if mask:
                    alpha = alpha * self.mask[l + 1][i].item()
                newax.imshow(im, alpha=alpha)
                newax.axis('off')
        print(fig)

    def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=1.0,
              lamb_weight_l1=1.0, lamb_coef=0., lamb_coefdiff=0., update_grid=True,
              grid_update_num=10, stop_grid_update_step=50, batch=-1):
        """Fit the model to ``dataset`` with MSE loss plus an optional regulariser.

        NOTE: this overrides ``nn.Module.train(mode)``.  ``nn.Module.eval()`` calls
        ``self.train(False)``, so a bool ``dataset`` is forwarded to the base class.
        That keeps ``eval()`` and ``train(True)`` working.

        Args:
            dataset: dict with 'train_input', 'train_label', 'test1_input', 'test1_label'.
            opt: "LBFGS", "Adam" or "SGD".
            steps: number of optimisation steps.
            log: refresh the progress-bar text every ``log`` steps.
            lamb: weight of the regulariser in the objective.
            lamb_l1, lamb_entropy: L1 and entropy terms on the activation scales.
            lamb_weight_l1: L1 term on the linear weights.
            lamb_coef, lamb_coefdiff: accepted for compatibility; currently unused.
            update_grid: re-place knots from training samples during early steps.
            grid_update_num: number of grid updates before ``stop_grid_update_step``.
            stop_grid_update_step: no grid updates at or after this step.
            batch: mini-batch size; -1 (or larger than the train set) uses all samples.

        Returns:
            dict of per-step lists: 'train_loss' and 'test1_loss' (RMSE) and 'reg'.
            'test2_loss' stays empty.

        Raises:
            ValueError: if ``opt`` is not one of the supported optimisers.
        """
        if isinstance(dataset, bool):
            return super().train(dataset)

        def reg(acts_scale):
            # L1 + entropy of each layer's activation scales, plus L1 of all linear weights.
            reg_ = 0.
            for scale in acts_scale:
                vec = scale.reshape(-1,)
                p = vec / torch.sum(vec)
                reg_ += lamb_l1 * torch.sum(vec) - lamb_entropy * torch.sum(p * torch.log2(p + 1e-4))
            for linear in self.linears:
                reg_ += lamb_weight_l1 * torch.sum(torch.abs(linear.weight))
            return reg_

        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        elif opt == "SGD":
            optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        elif opt == "LBFGS":
            optimizer = LBFGS(self.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe",
                              tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
        else:
            raise ValueError(f"unknown optimizer {opt!r}; expected 'Adam', 'SGD' or 'LBFGS'")

        pbar = tqdm(range(steps), desc='description')

        def loss_fn(pred, target):
            return torch.mean((pred - target) ** 2)

        # Guard against a zero frequency (stop_grid_update_step < grid_update_num).
        grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

        results = {'train_loss': [], 'test1_loss': [], 'test2_loss': [], 'reg': []}

        n_train = dataset['train_input'].shape[0]
        n_test = dataset['test1_input'].shape[0]
        batch_size = n_train if batch == -1 or batch > n_train else batch
        # Sampling without replacement cannot draw more test points than exist.
        test_batch_size = min(batch_size, n_test)

        for step in pbar:
            train_id = np.random.choice(n_train, batch_size, replace=False)
            test_id = np.random.choice(n_test, test_batch_size, replace=False)

            if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid:
                self.update_grid_from_samples(dataset['train_input'][train_id[:1000]].to(device))

            train_x = dataset['train_input'][train_id].to(device)
            train_y = dataset['train_label'][train_id].to(device)

            def closure():
                optimizer.zero_grad()
                objective = loss_fn(self.forward(train_x), train_y) + lamb * reg(self.acts_scale)
                objective.backward()
                return objective

            train_loss = loss_fn(self.forward(train_x), train_y)
            reg_ = reg(self.acts_scale)
            loss = train_loss + lamb * reg_
            test1_loss = loss_fn(self.forward(dataset['test1_input'][test_id].to(device)),
                                 dataset['test1_label'][test_id].to(device))

            if step % log == 0:
                pbar.set_description(" %.2e | %.2e | %.2e " % (
                    torch.sqrt(train_loss).cpu().detach().numpy(),
                    torch.sqrt(test1_loss).cpu().detach().numpy(),
                    reg_.cpu().detach().numpy()))

            if opt == "LBFGS":
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
            results['test1_loss'].append(torch.sqrt(test1_loss).cpu().detach().numpy())
            results['reg'].append(reg_.cpu().detach().numpy())

        return results

    def prune(self, threshold=1e-2):
        """Return a copy of the model keeping only important hidden neurons.

        A hidden neuron is kept when both its incoming and its outgoing importance exceed
        ``threshold``.  Importance is the activation scale times the largest |weight| on
        that side, taken from the last forward pass.  Also sets ``self.mask``, one 0/1
        tensor per layer, which is used by ``plot(mask=True)``.
        """
        mask = [torch.ones(self.width[0],)]
        active_neurons = [list(range(self.width[0]))]
        for i in range(len(self.acts_scale) - 1):
            in_important = self.acts_scale[i + 1] * torch.max(torch.abs(self.linears[i].weight), dim=1)[0] > threshold
            out_important = self.acts_scale[i + 1] * torch.max(torch.abs(self.linears[i + 1].weight), dim=0)[0] > threshold
            overall_important = in_important * out_important
            mask.append(overall_important.float())
            active_neurons.append(torch.where(overall_important)[0])
        active_neurons.append(list(range(self.width[-1])))
        mask.append(torch.ones(self.width[-1],))

        self.mask = mask

        # Rebuild with the same forward-affecting settings so outputs and state_dict keys match.
        linear_bias = self.linears[0].bias is not None
        scale_sp_trainable = self.act_fun[0].scale_sp.requires_grad if len(self.act_fun) > 0 else True
        model2 = LAN(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun, w0=self.w0,
                     linear_bias=linear_bias, scale_sp_trainable=scale_sp_trainable)
        model2.load_state_dict(self.state_dict())
        for i in range(len(self.acts_scale)):
            rows = active_neurons[i + 1]
            cols = active_neurons[i]
            linear = model2.linears[i]
            model2.biases[i].weight.data = model2.biases[i].weight.data[:, rows]
            linear.weight.data = linear.weight.data[rows][:, cols]
            if linear.bias is not None:
                linear.bias.data = linear.bias.data[rows]
            linear.in_features = len(cols)
            linear.out_features = len(rows)
            if i < len(self.acts_scale) - 1:
                model2.act_fun[i] = model2.act_fun[i].get_subset(rows)
            model2.width[i] = len(cols)

        return model2
        
    

