#!/usr/bin/env python
# coding: utf-8

import os
import random
import time

import numpy as np
import pandas as pd
from sklearn import datasets
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import copy

import pickle
from sklearn.datasets import fetch_california_housing
from sklearn import preprocessing


import scipy.stats
from scipy.signal import convolve2d, fftconvolve
import math
from math import ceil

import sys
import warnings

import torchvision
from torchvision.transforms import ToTensor, Normalize, Compose
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from collections import deque

from argparse import ArgumentParser

import einops
from einops import rearrange, repeat
from tqdm.auto import tqdm

if len(sys.argv) != 4:
    # sys.exit(str) prints the message to stderr and exits with status 1, so shell
    # and batch drivers can detect the bad invocation (a bare sys.exit() reports success).
    sys.exit('Usage: average_entropy.py start_epoch epoch_step process_id')


start_epoch = int(sys.argv[1])  # step counter to resume from; Train_WLMD starts at start_epoch + 1
epoch_step = int(sys.argv[2])   # number of dynamics steps to run in this invocation
process_id = int(sys.argv[3])   # used as the RNG seed; process 0 also writes the shared ../ outputs

# This code was taken directly from Neel Nanda's study of grokking:
# https://colab.research.google.com/drive/1F6_1_cWXE5M7WocUcpQWp3v8z4b1jL20

class HookPoint(nn.Module):
    """Identity module whose only purpose is to carry forward/backward hooks.

    User hooks have the signature ``fn(activation, name=...)``. ``add_hook`` wraps
    them into PyTorch's module-hook signature and remembers the returned handles
    so ``remove_hooks`` can detach them later.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Store the dotted module path; the owning model calls this at init."""
        self.name = name

    def add_hook(self, hook, dir='fwd'):
        """Attach ``hook`` as a 'fwd' or 'bwd' hook; any other ``dir`` raises ValueError."""
        def adapter(module, module_input, module_output):
            # For an identity module input and output coincide, so only the output is passed on.
            return hook(module_output, name=self.name)

        if dir == 'fwd':
            handle = self.register_forward_hook(adapter)
            self.fwd_hooks.append(handle)
            return
        if dir == 'bwd':
            handle = self.register_backward_hook(adapter)
            self.bwd_hooks.append(handle)
            return
        raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir='fwd'):
        """Detach hooks for 'fwd', 'bwd' or 'both'; any other ``dir`` raises ValueError."""
        if dir not in ('fwd', 'bwd', 'both'):
            raise ValueError(f"Invalid direction {dir}")
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def forward(self, x):
        return x

# Embed & Unembed
class Embed(nn.Module):
    """Token embedding; column ``W_E[:, t]`` is the d_model-vector of token ``t``."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        initial = torch.randn(d_model, d_vocab) / np.sqrt(d_model)
        self.W_E = nn.Parameter(initial)

    def forward(self, x):
        # For a 2-D token tensor x of shape (batch, pos), W_E[:, x] is
        # (d_model, batch, pos); move d_model to the last axis.
        gathered = self.W_E[:, x]
        return gathered.permute(1, 2, 0)

class Unembed(nn.Module):
    """Linear read-out from the residual stream to vocabulary logits (``x @ W_U``)."""

    def __init__(self, d_vocab, d_model):
        super().__init__()
        initial = torch.randn(d_model, d_vocab) / np.sqrt(d_vocab)
        self.W_U = nn.Parameter(initial)

    def forward(self, x):
        return torch.matmul(x, self.W_U)

# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned absolute positional embedding added to the residual stream."""

    def __init__(self, max_ctx, d_model):
        super().__init__()
        initial = torch.randn(max_ctx, d_model) / np.sqrt(d_model)
        self.W_pos = nn.Parameter(initial)

    def forward(self, x):
        # Use only as many position rows as the input has positions (second-to-last axis).
        n_positions = x.shape[-2]
        position_rows = self.W_pos[:n_positions]
        return x + position_rows

# LayerNorm
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with learned gain ``w_ln`` and bias ``b_ln``.

    ``model`` is a one-element list holding the owning model. Wrapping it in a list
    keeps nn.Module from registering the owner as a child module. When the owner's
    ``use_ln`` is false the layer is the identity. When no owner is given, the
    layer always normalises.

    Note: ``x.std`` is the unbiased (Bessel-corrected) standard deviation, and
    ``epsilon`` is added to the std rather than to the variance.
    """

    def __init__(self, d_model, epsilon=1e-4, model=None):
        super().__init__()
        # A fresh list per instance: a shared mutable default would be aliased across all LayerNorms.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        owner = self.model[0]
        # Previously a missing owner crashed here (None has no `use_ln`); now it means "normalise".
        if owner is not None and not owner.use_ln:
            return x
        x = x - x.mean(axis=-1)[..., None]
        x = x / (x.std(axis=-1)[..., None] + self.epsilon)
        x = x * self.w_ln
        x = x + self.b_ln
        return x

# Attention
class Attention(nn.Module):
    """Multi-head causal self-attention with a HookPoint on every intermediate.

    Weights: W_K / W_Q / W_V are (num_heads, d_head, d_model) and W_O is
    (d_model, num_heads * d_head). ``mask`` is a lower-triangular (n_ctx, n_ctx)
    buffer that blocks attention to later positions. Inputs are (batch, pos, d_model),
    as the 'bpd' einsum subscripts require.
    """

    def __init__(self, d_model, num_heads, d_head, n_ctx, model):
        super().__init__()
        self.model = model
        qkv_shape = (num_heads, d_head, d_model)
        # Parameters are drawn in the order K, Q, V, O (this fixes the RNG consumption order).
        self.W_K = nn.Parameter(torch.randn(*qkv_shape) / np.sqrt(d_model))
        self.W_Q = nn.Parameter(torch.randn(*qkv_shape) / np.sqrt(d_model))
        self.W_V = nn.Parameter(torch.randn(*qkv_shape) / np.sqrt(d_model))
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads) / np.sqrt(d_model))
        self.register_buffer('mask', torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head
        self.hook_k = HookPoint()
        self.hook_q = HookPoint()
        self.hook_v = HookPoint()
        self.hook_z = HookPoint()
        self.hook_attn = HookPoint()
        self.hook_attn_pre = HookPoint()

    def forward(self, x):
        seq_len = x.shape[-2]
        # Per-head keys / queries / values, each (batch, head, pos, d_head).
        keys = self.hook_k(torch.einsum('ihd,bpd->biph', self.W_K, x))
        queries = self.hook_q(torch.einsum('ihd,bpd->biph', self.W_Q, x))
        values = self.hook_v(torch.einsum('ihd,bpd->biph', self.W_V, x))
        # scores[b, i, q, p]: dot product of the query at q with the key at p.
        scores = torch.einsum('biph,biqh->biqp', keys, queries)
        causal = self.mask[:seq_len, :seq_len]
        # tril zeroes future positions, then they are pushed to -1e10 so softmax ignores them.
        scores = torch.tril(scores) - 1e10 * (1 - causal)
        scaled = self.hook_attn_pre(scores / np.sqrt(self.d_head))
        pattern = self.hook_attn(F.softmax(scaled, dim=-1))
        mixed = self.hook_z(torch.einsum('biph,biqp->biqh', values, pattern))
        # Concatenate heads: (batch, head, pos, d_head) -> (batch, pos, head * d_head).
        concatenated = einops.rearrange(mixed, 'b i q h -> b q (i h)')
        return torch.einsum('df,bqf->bqd', self.W_O, concatenated)

# MLP Layers
class MLP(nn.Module):
    """Position-wise two-layer perceptron: ``W_out @ act(W_in @ x + b_in) + b_out``.

    ``hook_pre`` and ``hook_post`` expose the hidden activations before and after the
    nonlinearity. ``act_type`` must be 'ReLU' or 'GeLU'.

    Raises:
        ValueError: if ``act_type`` is not a supported activation name.
    """

    _ACTIVATIONS = {'ReLU': F.relu, 'GeLU': F.gelu}

    def __init__(self, d_model, d_mlp, act_type, model):
        super().__init__()
        # An explicit check survives `python -O`. The former `assert` did not, and an
        # unknown act_type then silently made the layer purely linear.
        if act_type not in self._ACTIVATIONS:
            raise ValueError(f"act_type must be 'ReLU' or 'GeLU', got {act_type!r}")
        self.model = model
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model)/np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp)/np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

    def forward(self, x):
        # x is (batch, pos, d_model) per the 'bpd' subscripts; hidden is (batch, pos, d_mlp).
        x = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in)
        x = self._ACTIVATIONS[self.act_type](x)
        x = self.hook_post(x)
        x = torch.einsum('dm,bpm->bpd', self.W_out, x) + self.b_out
        return x

# Transformer Block
class TransformerBlock(nn.Module):
    """Pre-LN-free residual block: attention then MLP, each added back to the residual stream.

    HookPoints mark the residual stream before attention (``hook_resid_pre``), between
    the sublayers (``hook_resid_mid``) and after the MLP (``hook_resid_post``), plus
    each sublayer's output.
    """

    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        self.attn = Attention(d_model, num_heads, d_head, n_ctx, model=self.model)
        self.mlp = MLP(d_model, d_mlp, act_type, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, x):
        # The residual sum uses the raw input x, not the (possibly hook-modified) resid_pre.
        attn_in = self.hook_resid_pre(x)
        attn_delta = self.hook_attn_out(self.attn(attn_in))
        mid = self.hook_resid_mid(x + attn_delta)
        mlp_delta = self.hook_mlp_out(self.mlp(mid))
        return self.hook_resid_post(mid + mlp_delta)

# Full transformer
class Transformer(nn.Module):
    """Token embedding -> positional embedding -> residual blocks -> unembedding.

    No final LayerNorm is applied in ``forward``. Each block gets ``[self]`` (a list)
    as its ``model`` back-reference. ``use_ln`` is the flag LayerNorm modules consult
    through that reference; none are currently active. After construction, every
    HookPoint is named after its dotted module path.
    """

    def __init__(self, num_layers, d_vocab, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, use_cache=False, use_ln=True):
        super().__init__()
        self.cache = {}
        self.use_cache = use_cache

        # Submodules are created in this order: embed, pos_embed, blocks, unembed (RNG order).
        self.embed = Embed(d_vocab, d_model)
        self.pos_embed = PosEmbed(n_ctx, d_model)
        layers = [
            TransformerBlock(d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self])
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.unembed = Unembed(d_vocab, d_model)
        self.use_ln = use_ln

        for path, submodule in self.named_modules():
            if isinstance(submodule, HookPoint):
                submodule.give_name(path)

    def forward(self, x):
        residual = self.pos_embed(self.embed(x))
        for block in self.blocks:
            residual = block(residual)
        return self.unembed(residual)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """Return every submodule whose dotted name contains 'hook'."""
        points = []
        for path, submodule in self.named_modules():
            if 'hook' in path:
                points.append(submodule)
        return points

    def remove_all_hooks(self):
        """Detach every forward and backward hook from every hook point."""
        for point in self.hook_points():
            point.remove_hooks('both')

    def cache_all(self, cache, incl_bwd=False):
        """Store every HookPoint activation (and optionally its gradient) into ``cache``."""
        def store_activation(tensor, name):
            cache[name] = tensor.detach()

        def store_gradient(tensor, name):
            cache[name + '_grad'] = tensor[0].detach()

        for point in self.hook_points():
            point.add_hook(store_activation, 'fwd')
            if incl_bwd:
                point.add_hook(store_gradient, 'bwd')

def full_loss(model, data, device):
    """Cross-entropy of ``model`` over all of ``data`` in one batch, scored at the last position.

    The whole dataset of (tokens, label) pairs is collated into a single batch and
    moved to ``device``. The logits at the final sequence position are compared with
    the labels. The returned scalar tensor stays attached to the autograd graph.
    """
    whole = torch.utils.data.DataLoader(data, batch_size=len(data), shuffle=False)
    inputs, targets = (tensor.to(device) for tensor in next(iter(whole)))
    final_logits = model(inputs)[:, -1]
    return torch.nn.functional.cross_entropy(final_logits, targets)

def full_accuracy(model, data, device):
    """Return the fraction of ``data`` whose argmax prediction at the final position equals the label.

    The whole dataset is evaluated as one batch on ``device``. The forward pass runs
    under ``torch.no_grad()``: the result is a plain float, so recording an autograd
    graph would only waste memory and time.
    """
    loader = torch.utils.data.DataLoader(data, batch_size=len(data), shuffle=False)
    x, labels = next(iter(loader))
    x = x.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        logits = model(x)[:, -1]  # final position only
    predictions = torch.argmax(logits, dim=1)
    return torch.sum(predictions == labels).item() / len(labels)

# ---- Run configuration and dataset -------------------------------------------
# The RNG seed is the process id, so parallel processes draw different initial
# weights and noise. The train/test split below uses its own fixed generator
# (seed 42), so every process sees the same split.
seed = process_id
p = 67 #the modulus
fraction = 0.5
    
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#torch.set_default_dtype(torch.float64)

# Every ordered pair (x, y) with 0 <= x, y < p becomes the prompt [x, y, p].
# Token p stands for "=". The target is (x + y) mod p.
equals_token = p
x, y = torch.meshgrid(torch.arange(p), torch.arange(p), indexing='ij')
x = x.flatten()
y = y.flatten()
equals = torch.ones(x.shape, dtype=torch.int64) * equals_token
prompts = torch.stack([x, y, equals], dim=1).to(device)
answers = ((x + y) % p).to(device) #CHANGE it for other modular arithmetic tasks

# A `fraction` share of the p*p pairs goes to the training set; the rest goes to the test set.
data = torch.utils.data.TensorDataset(prompts, answers)
train, test = torch.utils.data.random_split(data, 
                                [int(fraction * len(data)),
                                len(data) - int(fraction * len(data))
                                ],generator=torch.Generator().manual_seed(42))

# One-layer, four-head transformer without LayerNorm. d_vocab = p + 1 leaves room for the "=" token.
net = Transformer(num_layers=1, 
                    d_vocab=p+1, 
                    d_model=128, #In the paper, it is set to 32, 64, 128, corresponding to the results in Figures 2(a), 2(b), and 2(c).
                    d_mlp=512,   #In the paper, it is set to 128, 256, 512, corresponding to the results in Figures 2(a), 2(b), and 2(c).
                    d_head=32,
                    num_heads=4,
                    n_ctx=3, # context length
                    act_type='ReLU', 
                    use_cache=False, 
                    use_ln=False # use LayerNorm
                ).to(device)

# ---- Entropy grid and dynamics hyper-parameters -------------------------------
# Both grid axes share these edges. Train_WLMD samples
# x = log(train loss + x_loss_bias) and y = 10 * smoothed validation accuracy.
bin_begin=-30
bin_end=30
bin_width=0.02
# Standard deviation and truncation radius of the Gaussian kernel deposited into the entropy each step.
gaussianSigma=5*bin_width
gaussianCutoff=4*gaussianSigma
# Extra padding bins on each side of the grid.
margin=0
# Integration step of the velocity/weight update.
time_step=1e-4
# Slope of the sigmoid applied to (correct logit - highest wrong logit) in the smoothed accuracy.
smoothed_accuracy_slope=7 
# Added to the training loss before taking its log.
x_loss_bias=0
# Temperature of the Langevin noise and of freshly drawn velocities.
kT = 1.0
# Fraction of the velocity removed by friction each step.
frictionCoefficient = 1e-2
goldilocks_norm=30 #constrain the weight norm; if not need, set it into 0.
# Window of log(train loss) being studied; leaving it triggers a velocity reflection in Train_WLMD.
minTrainLossToStudy = -20.0
maxTrainLossToStudy = 1.0
# Curvature of the quadratic entropy penalty applied for x > maxTrainLossToStudy.
forbiddenRegionEntropyIncrease = 100/time_step

# The per-step entropy deposit weight ramps linearly from 0 to maxWLFactor over
# the first maxWLFactorStep steps and stays constant afterwards.
maxWLFactor=100
maxWLFactorStep=1e7

# Bins per axis: (bin_end - bin_begin) / bin_width, plus 2*margin padding bins and 2 extra bins.
bin_num = int((bin_end-bin_begin)/bin_width+2*margin+2)

# Bin coordinates. xx varies along columns and yy along rows, matching the
# [y_bin, x_bin] indexing used for `entropy` and `histogram`.
binEdges = (np.array(range(bin_num))-margin)*bin_width+bin_begin
xx, yy = np.meshgrid(binEdges, binEdges)
if process_id == 0 and start_epoch == 0:
    np.savez_compressed("../XAndY_optimized.npz", x=xx, y=yy)

# Quadratic penalty that is zero for x <= maxTrainLossToStudy. Train_WLMD adds it to
# the entropy while sampling, which pushes the dynamics away from that region, and
# subtracts it from anything it saves.
entropyBias = forbiddenRegionEntropyIncrease*(xx>maxTrainLossToStudy)*(xx-maxTrainLossToStudy)*(xx-maxTrainLossToStudy)

# Start from a previously accumulated entropy grid if one exists (presumably an average
# over earlier runs; TODO confirm); otherwise start from a flat zero grid.
# NOTE: pickle.load can execute arbitrary code; only load files this pipeline wrote.
if os.path.isfile('../ave_entropy.pickle'):
    with open('../ave_entropy.pickle', 'rb') as pickled_file:
        entropy = pickle.load(pickled_file)
else:
    entropy = torch.zeros((bin_num, bin_num))

# Resume the velocities saved at the end of a previous invocation of this script, if any.
if os.path.isfile('velocity.pickle'):
    with open('./velocity.pickle', 'rb') as pickled_file:
        velocity = pickle.load(pickled_file)
else:
    velocity = None

# Visit counts per bin, indexed [y_bin, x_bin].
histogram = np.zeros((bin_num, bin_num))

def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module from N(0, 0.01**2).

    Biases and all other module types are left untouched. Intended for
    ``module.apply(init_weights)``.

    NOTE: the Transformer in this script stores its weights as raw ``nn.Parameter``
    tensors and contains no ``nn.Linear`` layers, so applying this to ``net``
    currently changes nothing.
    """
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

# The generator variable is `param`, so it cannot be confused with the modulus `p`.
nParameters = sum(param.numel() for param in net.parameters() if param.requires_grad)

print('Total number of parameters is: ' + str(nParameters))

# Per-step noise scale. Under v -> (1 - gamma) * v - sigma * xi, the stationary
# variance is sigma**2 / (gamma * (2 - gamma)), which this choice makes exactly kT.
langevinForceStddev = np.sqrt(frictionCoefficient * (2-frictionCoefficient) * kT)

# Resume from a checkpoint if one exists. map_location lets a checkpoint saved on a GPU
# load on a CPU-only machine (and vice versa).
if os.path.isfile('net.pth'):
    net.load_state_dict(torch.load('./net.pth', map_location=device))
else:
    net.to(torch.float32)
    net.apply(init_weights)

net = net.to(device)

loss = nn.CrossEntropyLoss()

ts='{:.2e}'.format(time_step)
ei='{:.2e}'.format(forbiddenRegionEntropyIncrease)
temperature='{:.2e}'.format(kT)
print(f"{p},{temperature},{ts},{frictionCoefficient},{ei},{maxTrainLossToStudy}")

def Train_WLMD(net, train_WL, test_WL, entropy, entropyBias, histogram, numepochs, bin_num, bin_width, start_epoch, goldilocks_norm, time_step=None, velocity=None):
    """Run Wang-Landau-style biased Langevin dynamics on the weights of ``net``.

    Each step does the following:
    1. Evaluate the point (x_WL, y_WL) = (log(train loss + x_loss_bias),
       10 * smoothed validation accuracy).
    2. Backpropagate the central-difference entropy gradient at that point into the
       weights.
    3. Take a friction + noise velocity step. The velocity is instead reflected on
       the first step after x_WL leaves (minTrainLossToStudy, maxTrainLossToStudy).
    4. Deposit a Gaussian kernel into ``entropy`` and a count into ``histogram`` at
       the point's bin.

    Args:
        net: model whose parameters are updated in place.
        train_WL: dataset giving the training loss (x coordinate).
        test_WL: dataset giving the validation loss/accuracy and the smoothed
            accuracy (y coordinate).
        entropy: (bin_num, bin_num) tensor indexed [y_bin, x_bin]. It may be
            modified in place; use the returned tensor.
        entropyBias: array of the same shape. It is added while sampling and
            removed again before returning or saving.
        histogram: numpy array of visit counts indexed [y_bin, x_bin].
        numepochs: number of dynamics steps to run.
        bin_num, bin_width: grid size per axis and bin spacing.
        start_epoch: step counter to resume from; steps run from start_epoch + 1.
        goldilocks_norm: if nonzero, the global weight L2 norm is capped at this value.
        time_step: integration step. Required; the process exits with status 1 if None.
        velocity: list of per-parameter velocity tensors, or None to draw fresh ones.

    Returns:
        (losses_train, losses_test, net, entropy, histogram, velocity, E_traj).
        The loss lists and kinetic-energy trajectory E_traj are sampled every 1000
        steps. ``entropy`` is returned on the CPU with the bias removed.

    Raises:
        ValueError: if a sampled point falls outside the interior of the bin grid.
    """
    # Validate before touching caller-owned state (entropy is modified in place below).
    if time_step is None:
        print("Error: time_step is not set. This version of program no longer has a default time_step.")
        sys.exit(1)

    start_epoch = start_epoch + 1
    print('Starting step is: ' + str(start_epoch))
    entropy += entropyBias

    E_traj = []
    losses_train = []
    losses_test = []

    # 2-D Gaussian kernel on the bin grid, truncated at +-gaussianCutoff and normalised
    # to sum to 1. It is deposited into `entropy` at the current bin every step.
    gau_1dloc = np.arange(-gaussianCutoff, 1.0000001 * gaussianCutoff, bin_width)
    gx, gy = np.meshgrid(gau_1dloc, gau_1dloc)
    gau_dist = scipy.stats.multivariate_normal.pdf(np.dstack((gx, gy)), (0, 0), gaussianSigma**2) * bin_width
    gau_dist = gau_dist / np.sum(gau_dist)
    gau_dist = torch.tensor(gau_dist).to(device)
    entropy = entropy.to(device)

    if velocity is None:
        # Fresh velocities: N(0, kT) entries, each tensor scaled by its parameter's L2 norm.
        velocity = []
        for param in net.parameters():
            vel = torch.normal(0, np.sqrt(kT), param.shape).to(device)
            vel *= np.sqrt(param.pow(2).sum().item())
            velocity.append(vel)

    def bin_id_cal_meta(x_WL, y_WL, bin_width, bin_num):
        """Return the integer (x_bin, y_bin) grid indices of the point (x_WL, y_WL)."""
        bin_id_x = int(torch.div(x_WL - bin_begin, bin_width, rounding_mode='floor') + margin)
        bin_id_y = int(torch.div(y_WL - bin_begin, bin_width, rounding_mode='floor') + margin)
        # Off-grid points would otherwise index with a negative (silently wrapping) or
        # out-of-range bin. SL_grad needs one neighbour on each side.
        if not (1 <= bin_id_x <= bin_num - 2 and 1 <= bin_id_y <= bin_num - 2):
            raise ValueError(
                f"point ({float(x_WL)}, {float(y_WL)}) falls in bin ({bin_id_x}, {bin_id_y}), "
                f"outside the interior of the {bin_num}x{bin_num} grid"
            )
        return bin_id_x, bin_id_y

    def SL_grad(S, x_WL, y_WL, bin_width, bin_num):
        """Central-difference gradient (dS/dx, dS/dy) of the grid S at the point's bin."""
        bin_id_x, bin_id_y = bin_id_cal_meta(x_WL, y_WL, bin_width, bin_num)
        y_grad = (S[bin_id_y + 1][bin_id_x] - S[bin_id_y - 1][bin_id_x]) / bin_width / 2
        x_grad = (S[bin_id_y][bin_id_x + 1] - S[bin_id_y][bin_id_x - 1]) / bin_width / 2
        return torch.as_tensor(x_grad, device=device), torch.as_tensor(y_grad, device=device)

    def histogram_updates(histogram, x_WL, y_WL, bin_width, bin_num):
        """Add one visit to the point's bin."""
        bin_id_x, bin_id_y = bin_id_cal_meta(x_WL, y_WL, bin_width, bin_num)
        histogram[bin_id_y][bin_id_x] += 1
        return histogram

    def entropy_updates_meta(entropy, x_WL, y_WL, gau_dist, scale_factor, bin_width, bin_num):
        """Add scale_factor * gau_dist, centred on the point's bin, to the entropy grid (in place)."""
        width = gau_dist.shape[0] // 2
        bin_id_x, bin_id_y = bin_id_cal_meta(x_WL, y_WL, bin_width, bin_num)
        entropy[bin_id_y-width:bin_id_y+width+1, bin_id_x-width:bin_id_x+width+1] += scale_factor * gau_dist
        return entropy

    onceUnderBound = 0  # 1 while the previous step was inside the studied train-loss window
    if_reflect = 0
    for epoch in range(start_epoch, start_epoch + numepochs):
        kineticEnergy = 0.5 * sum((vel**2).sum() for vel in velocity)

        # Rescale the weights so their global L2 norm is at most goldilocks_norm (if enabled).
        regularization_loss = np.sqrt(sum(param.pow(2).sum().item() for param in net.parameters()))
        if goldilocks_norm != 0:
            for param in net.parameters():
                param.data *= min(1, goldilocks_norm / regularization_loss)

        loss_train = full_loss(net, train_WL, device)
        ln_loss_train = torch.log(loss_train + x_loss_bias)

        # Validation loss/accuracy are only logged, so no autograd graph is needed.
        with torch.no_grad():
            loss_val = full_loss(net, test_WL, device)
            accuracy_val = full_accuracy(net, test_WL, device)

        # Smoothed (differentiable) validation accuracy: mean sigmoid of the margin
        # between the correct logit and the largest wrong logit, at the final position.
        loader = torch.utils.data.DataLoader(test_WL, batch_size=len(test_WL), shuffle=False)
        x, labels = next(iter(loader))
        x = x.to(device)
        labels = labels.to(device)
        logits = net(x)[:, -1]

        value, index = torch.topk(input=logits, k=2, dim=1)
        correct = (labels == index[:, 0])
        # If the top-1 prediction is correct, the highest wrong logit is the runner-up;
        # otherwise it is the top-1 logit.
        highestWrongLogit = value[:, 0] + correct * (value[:, 1] - value[:, 0])
        correctLogit = logits[torch.arange(logits.size(0)), labels]
        logit_diff = correctLogit - highestWrongLogit
        smoothed_accuracy = torch.mean(torch.sigmoid(smoothed_accuracy_slope * logit_diff))

        x_WL = ln_loss_train
        y_WL = 10 * smoothed_accuracy

        # Reflect the velocity on the first step after leaving the studied window.
        if minTrainLossToStudy < x_WL < maxTrainLossToStudy:
            onceUnderBound = 1
            if_reflect = 0
        else:
            if_reflect = onceUnderBound
            onceUnderBound = 0

        # The parameter gradient of combined_loss is -dS/dtheta, with the grid gradient
        # held fixed. Adding it to the velocity drives the weights toward bins with less
        # deposited entropy.
        X_grad, Y_grad = SL_grad(entropy, x_WL, y_WL, bin_width, bin_num)
        combined_loss = -1 * X_grad * x_WL - 1 * Y_grad * y_WL
        net.zero_grad()
        combined_loss.backward()

        with torch.no_grad():
            if if_reflect == 0:
                for m, param in enumerate(net.parameters()):
                    velocity[m] += param.grad * time_step
                    velocity[m] -= frictionCoefficient * velocity[m] + langevinForceStddev * torch.randn_like(velocity[m], device=device)
                    param += velocity[m] * time_step
            else:
                for m, param in enumerate(net.parameters()):
                    # Reflect each last-axis row of the velocity about the plane normal to
                    # the matching gradient row. NaNs from zero-norm gradient rows are zeroed.
                    norm_grad = torch.div(param.grad, torch.norm(param.grad, p=2, dim=-1, keepdim=True))
                    delta_v = 2 * norm_grad * torch.sum(norm_grad * velocity[m], dim=-1, keepdim=True)
                    velocity[m] -= delta_v
                    velocity[m] -= frictionCoefficient * velocity[m] + langevinForceStddev * torch.randn_like(velocity[m], device=device)
                    velocity[m][velocity[m] != velocity[m]] = 0.0
                    param += velocity[m] * time_step

        # Deposit weight ramps linearly from 0 to maxWLFactor over the first
        # maxWLFactorStep steps, then stays constant.
        if epoch < maxWLFactorStep:
            scale_factor = maxWLFactor / maxWLFactorStep * epoch
        else:
            scale_factor = maxWLFactor

        histogram = histogram_updates(histogram, x_WL, y_WL, bin_width, bin_num)
        entropy = entropy_updates_meta(entropy, x_WL, y_WL, gau_dist, scale_factor, bin_width, bin_num)

        if epoch % 1000 == 0:
            E_traj.append(kineticEnergy)
            losses_train.append(loss_train.item())
            losses_test.append(loss_val.item())

            if epoch % 1000000 == 0 and process_id == 0:
                temp = np.copy(entropy.cpu())
                temp -= entropyBias
                temp[temp < 1e-5] = 0
                np.savez_compressed("../entropy" + str(epoch // 1000000) + "M.npz", entropy=temp, histogram=histogram)
            print(f"{x_WL}, {accuracy_val}, {regularization_loss}")
    entropy = entropy.to('cpu')
    entropy -= entropyBias
    return losses_train, losses_test, net, entropy, histogram, velocity, E_traj



train_loss_sum, test_loss_sum, net, entropy, histogram, velocity, E_traj = Train_WLMD(
    net, train, test, entropy, entropyBias, histogram, epoch_step, bin_num, bin_width,
    start_epoch, goldilocks_norm, time_step=time_step, velocity=velocity,
)

# Persist the sampler state in the working directory. The next invocation of this
# script reloads ./velocity.pickle and ./net.pth to resume.
_pickled_outputs = (
    ('./entropy.pickle', entropy),
    ('./histogram.pickle', histogram),
    ('./velocity.pickle', velocity),
    ('./E_traj.pickle', E_traj),
    ('./train_loss_sum.pickle', train_loss_sum),
    ('./test_loss_sum.pickle', test_loss_sum),
)
for _out_path, _payload in _pickled_outputs:
    with open(_out_path, 'wb') as pickle_file:
        pickle.dump(_payload, pickle_file)

torch.save(net.state_dict(), './net.pth')

