import torch
import pickle
import argparse
import os
import csv
import random
import copy
import logging
from dataset import LongRangeDataset
from models import TransformerLanguageModel, LSTMLanguageModel, LRULanguageModel, MambaLanguageModel
from utils import ModelHandler
from datetime import datetime
import numpy as np

import faulthandler

faulthandler.enable()

def config_seed(seed):
    """Seed the PyTorch, Python and NumPy global random number generators.

    Args:
        seed: Integer seed applied to all three generators.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    # NumPy's legacy seeding only accepts values in [0, 2**32 - 1], so fold
    # larger or negative seeds into that range instead of raising.
    np.random.seed(seed % (2 ** 32))

def train(data_file, model_type="lstm", embed_dim=128, num_layers=2,
    lr=1e-3, hidden_size=256, ff_dim=512, num_heads=4, batch_size=64,
    save_location='./models/', r_min=0.9, r_max=0.999,epochs=500,
    rnd_seed=None, data_percent=1.0, force_cpu=False, no_attention=True,
    pos_enc=False, d_conv=4, d_state=16, expand=2):
    """Build, train and evaluate one language model; return its test accuracy.

    The model is trained through ``ModelHandler``, reloaded from
    ``save_location``, evaluated with "beam_search" on the test and val
    splits, and the test-split results are pickled under ``results/``.

    Args:
        data_file: Path handed to ``LongRangeDataset``; its base name
            (without extension) becomes part of the saved model name.
        model_type: One of "lstm", "lru", "transformer" or "mamba".
        embed_dim: Embedding size passed to every model constructor.
        num_layers: Layer count (passed to lstm / lru / transformer).
        lr: Learning rate for the Adam optimiser.
        hidden_size: Hidden size (lstm / lru only).
        ff_dim: Feed-forward size (transformer only).
        num_heads: Number of attention heads (transformer only).
        batch_size: Batch size passed to the handler.
        save_location: Directory passed to ``ModelHandler`` and used to
            locate the saved model for reloading.
        r_min, r_max: Passed to ``LRULanguageModel`` (lru only).
        epochs: Number of epochs passed to ``handler.train_model``.
        rnd_seed: Seed to use; a random one in [1, 10000] is drawn when None.
        data_percent: Passed to ``LongRangeDataset``.
        force_cpu: Use the CPU even when CUDA is available.
        no_attention: lstm / lru are built with ``use_attention=not no_attention``.
        pos_enc: Passed to ``TransformerLanguageModel`` (transformer only).
        d_conv, d_state, expand: Passed to ``MambaLanguageModel`` (mamba only).

    Returns:
        The first value returned by ``handler.test_model`` on the test split.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Validate before any seeding, logging or data loading so a typo fails
    # immediately with a clear message.
    supported_types = ("lstm", "lru", "transformer", "mamba")
    if model_type not in supported_types:
        raise ValueError("Unknown model_type {!r}; expected one of {}".format(
            model_type, ", ".join(supported_types)))

    seed = rnd_seed if rnd_seed is not None else random.randint(1, 10000)
    config_seed(seed)

    logging.basicConfig(filename='training.log', filemode='a',
        format='%(asctime)s - %(message)s', level=logging.INFO)

    dataset = LongRangeDataset(data_file, data_percent)
    print(len(dataset.all_data['train']))

    if force_cpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Tag shared by the two recurrent models' file names.
    attention_tag = "no_attention" if no_attention else "with_attention"

    if model_type == "lstm":
        model = LSTMLanguageModel(embed_dim, dataset.vocab_size,
            hidden_size, num_layers, pad_idx=dataset.pad_idx,
            device=device, use_attention=not no_attention).to(device)
        add_model_info = ["hidden_size_{}".format(hidden_size),
                            attention_tag]
    elif model_type == "lru":
        model = LRULanguageModel(embed_dim, dataset.vocab_size,
            hidden_size, num_layers, pad_idx=dataset.pad_idx,
            device=device, r_min=r_min, r_max=r_max,
            use_attention=not no_attention).to(device)
        add_model_info = ["hidden_size_{}".format(hidden_size),
                            "r_min_{:.4f}".format(r_min),
                            "r_max_{:.4f}".format(r_max),
                            attention_tag]
    elif model_type == "transformer":
        model = TransformerLanguageModel(embed_dim, dataset.vocab_size,
            num_layers, num_heads, ff_dim, pad_idx=dataset.pad_idx,
            device=device, pos_enc=pos_enc).to(device)
        add_model_info = ["heads_{}".format(num_heads),
                            "ff_dim_{}".format(ff_dim),
                            "pos_enc" if pos_enc else "no_pos_enc"]
    elif model_type == "mamba":
        # Uses the function argument, not the script-level ``args`` object,
        # so train() also works when called outside the CLI script.
        model = MambaLanguageModel(embed_dim, dataset.vocab_size,
            d_conv=d_conv, d_state=d_state,
            expand=expand, pad_idx=dataset.pad_idx,
            device=device).to(device)
        add_model_info = ["d_conv_{}".format(d_conv),
                            "d_state_{}".format(d_state),
                            "expand_{}".format(expand)]

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model with {} parameters".format(num_params))

    optimise = torch.optim.Adam(model.parameters(), lr=lr)

    # The file name encodes the full configuration, joined by '___'.
    model_info = [model_type, data_file.split("/")[-1].split(".")[0],
        "parameters_{}".format(num_params), "embed_{}".format(embed_dim),
        "layers_{}".format(num_layers), "lr_{:.5f}".format(lr),
        "seed_{}".format(seed), "batch_size_{}".format(batch_size
            )] + add_model_info
    model_name = '{}.model'.format('___'.join(model_info))

    # Wall-clock timestamp (whole seconds) used as the run id in the log.
    curr_dt = datetime.now()
    train_id = int(round(curr_dt.timestamp()))
    logging.info("ID: {}, beginning training on model {}".format(train_id,
        model_name))

    handler = ModelHandler(model, optimise, dataset, device, save_location,
        model_name)
    handler.train_model(batch_size, epochs)

    # Reload from disk before evaluating (presumably the checkpoint written
    # by train_model -- confirm against ModelHandler).
    handler.load_model(os.path.join(save_location, model_name),
        batch_size)
    test_acc, results = handler.test_model("beam_search", split="test",
        record_results=True)
    # The validation return value is not used here; the call is kept in case
    # test_model(record_results=True) has side effects -- confirm in ModelHandler.
    handler.test_model("beam_search", split="val", record_results=True)

    # Make sure the output directory exists so a finished run is not lost,
    # and close the file handle deterministically.
    os.makedirs('results', exist_ok=True)
    with open('results/{}_results.pickle'.format(model_name), 'wb') as results_file:
        pickle.dump(results, results_file)
    return test_acc

class BayesOptim:
    """Kernel-based search over a fixed grid of hyperparameter combinations.

    Each row of ``all_combos`` is one candidate configuration. Candidates
    already evaluated are tracked in ``tried`` (row indices) with their scores
    in ``f``; the remaining row indices are in ``not_tried``. The next
    candidate is the untried row maximising ``mu + kappa * var``, where ``mu``
    and ``var`` are computed from a squared-exponential kernel over the
    max-scaled hyperparameters.
    """

    def __init__(self, all_combos, maxes, device, no_attention,
        pos_enc, data_file, model_type):
        """Store the search configuration and precompute the kernel matrix.

        Args:
            all_combos: 2-D tensor, one hyperparameter combination per row.
            maxes: Per-column divisors used to scale the combinations.
            device: Torch device for all internal tensors.
            no_attention: Forwarded to ``train`` for lstm / lru models.
            pos_enc: Forwarded to ``train`` for transformer models.
            data_file: Data file path forwarded to ``train``.
            model_type: "lstm", "transformer", "lru" or "mamba".
        """
        self.data_file = data_file
        self.model_type = model_type
        self.maxes = maxes
        self.device = device
        self.no_attention = no_attention
        self.pos_enc = pos_enc
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result has to be stored for the combinations to live on `device`.
        self.all_combos = all_combos.to(self.device)
        self.calculate_K_all()
        self.not_tried = torch.arange(all_combos.shape[0],
            dtype=torch.int32, device=self.device)
        self.tried = torch.tensor([]).int().to(self.device)
        self.f = torch.tensor([]).to(self.device)
        self.results = []

    def covar(self, x_i, x_j):
        """Return the kernel value exp(-0.5 * ||x_i - x_j||^2) for two vectors."""
        return torch.exp(-0.5*(torch.linalg.vector_norm(x_i-x_j)**2))

    def calculate_K_all(self):
        """Compute ``self.K_all``, the kernel matrix between all combinations.

        The first ``len(self.maxes)`` columns are divided by the matching entry
        of ``self.maxes`` before distances are taken. Entry (i, j) equals
        ``self.covar(scaled[i], scaled[j])``, computed for all pairs at once
        instead of in a Python double loop.
        """
        scaled_combos = self.all_combos.clone()
        if not scaled_combos.is_floating_point():
            # Integer grids must be promoted, otherwise the scaled values
            # would be truncated when written back.
            scaled_combos = scaled_combos.to(torch.get_default_dtype())
        num_scaled = len(self.maxes)
        divisors = torch.as_tensor(self.maxes, dtype=scaled_combos.dtype,
            device=scaled_combos.device)
        scaled_combos[:, :num_scaled] = scaled_combos[:, :num_scaled] / divisors

        # Pairwise squared Euclidean distances via broadcasting: (N, 1, D) - (1, N, D).
        diffs = scaled_combos.unsqueeze(1) - scaled_combos.unsqueeze(0)
        sq_dists = diffs.pow(2).sum(dim=-1)
        # Keep the default dtype so the matrix stays compatible with `self.f`.
        self.K_all = torch.exp(-0.5 * sq_dists).to(
            dtype=torch.get_default_dtype(), device=self.device)

    def _mark_tried(self, pos):
        """Move the entry at position ``pos`` of ``not_tried`` into ``tried``.

        Returns:
            The row index into ``all_combos`` that was moved, as a Python int.
        """
        chosen = self.not_tried[pos]
        self.tried = torch.concat([self.tried, chosen.unsqueeze(0)], dim=0)
        self.not_tried = torch.concat(
            [self.not_tried[:pos], self.not_tried[pos+1:]], dim=0)
        return chosen.item()

    def maximise_acquisition_function(self, kappa=0.5):
        """Pick the untried combination with the highest acquisition score.

        With K the kernel among tried rows and k the kernel between untried
        and tried rows, this uses ``mu = k K^-1 f`` and
        ``var = 1 - diag(k K^-1 k^T)`` and scores each untried row as
        ``mu + kappa * var``. The winner is moved from ``not_tried`` to ``tried``.

        Args:
            kappa: Weight of the variance term in the score.

        Returns:
            Tuple ``(hyperparams, to_try)``: the chosen row of ``all_combos``
            and its row index as a Python int.

        Raises:
            RuntimeError: If every combination has already been tried.
        """
        if self.not_tried.numel() == 0:
            raise RuntimeError(
                "All hyperparameter combinations have already been tried")

        # Index tensors are stored as int32; cast to long for indexing.
        tried_idx = self.tried.long()
        untried_idx = self.not_tried.long()
        K = self.K_all[tried_idx][:, tried_idx]
        k = self.K_all[untried_idx][:, tried_idx]

        # k K^-1 is shared by the mean and the variance; compute it once.
        k_K_inv = torch.mm(k, torch.linalg.inv(K))
        mu = torch.mv(k_K_inv, self.f)
        # Row-wise dot product gives the diagonal of k K^-1 k^T.
        var = 1 - (k_K_inv * k).sum(dim=1)
        scores = mu + (kappa * var)
        _, indices = torch.max(scores, dim=0)
        idx = indices.item()
        to_try = self._mark_tried(idx)
        hyperparams = self.all_combos[to_try]
        return hyperparams, to_try

    def _run_trial(self, hyperparams):
        """Train one model with the given combination and return its score.

        The column layout depends on ``self.model_type``:
        lstm / lru: embed_dim, num_layers, lr, batch_size, hidden_size;
        transformer: embed_dim, num_layers, lr, batch_size, ff_dim, num_heads;
        mamba: embed_dim, lr, batch_size, d_state, d_conv, expand.

        Raises:
            ValueError: If ``self.model_type`` is not a supported name.
        """
        hp = hyperparams.tolist()
        if self.model_type in ("lstm", "lru"):
            return train(self.data_file, model_type=self.model_type,
                embed_dim=int(hp[0]), num_layers=int(hp[1]),
                lr=hp[2], batch_size=int(hp[3]),
                hidden_size=int(hp[4]), no_attention=self.no_attention)
        if self.model_type == "transformer":
            return train(self.data_file, model_type=self.model_type,
                embed_dim=int(hp[0]), num_layers=int(hp[1]),
                lr=hp[2], batch_size=int(hp[3]),
                ff_dim=int(hp[4]), num_heads=int(hp[5]),
                pos_enc=self.pos_enc)
        if self.model_type == "mamba":
            return train(self.data_file, model_type=self.model_type,
                embed_dim=int(hp[0]),
                lr=hp[1], batch_size=int(hp[2]),
                d_state=int(hp[3]), d_conv=int(hp[4]),
                expand=int(hp[5]))
        raise ValueError("Unknown model_type {!r}".format(self.model_type))

    def _record(self, combo_index, score):
        """Append a finished trial's score to ``f`` and ``results``, then print."""
        self.f = torch.concat([self.f,
            torch.tensor(score).unsqueeze(0).to(self.device)], dim=0)
        self.results.append([combo_index, score])
        print(self.results)

    def initialise(self):
        """Evaluate one uniformly random untried combination."""
        idx = random.randint(0, self.not_tried.shape[0]-1)
        trying = self._mark_tried(idx)
        hyperparams = self.all_combos[trying]
        print(hyperparams)
        score = self._run_trial(hyperparams)
        self._record(trying, score)

    def optimise(self, num_samples=5):
        """Evaluate the combination chosen by the acquisition function.

        Args:
            num_samples: Unused; kept so existing callers keep working.
        """
        hyperparams, trying = self.maximise_acquisition_function()
        print(hyperparams)
        score = self._run_trial(hyperparams)
        self._record(trying, score)

# ---- Command-line interface ----
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, help="Model to use",
    default='')
parser.add_argument('--filename', type=str, help="Data file to use",
    default='')
parser.add_argument('--no_attention', action='store_true',
    help="LSTM or LRU without attention")
parser.add_argument('--pos_enc', action='store_true',
    help="Add position encoding for transformers")
args = parser.parse_args()
model_type = args.model_type
filename = args.filename
no_attention = args.no_attention
pos_enc = args.pos_enc

# Fail fast on a missing or misspelled model type; otherwise no search space
# is selected below and the script dies later with an unrelated NameError.
if model_type not in ("lstm", "transformer", "lru", "mamba"):
    parser.error("--model_type must be one of lstm, transformer, lru, mamba "
        "(got {!r})".format(model_type))

# ---- Candidate values for each hyperparameter ----
embed_dim = [32, 64, 128]
num_layers = [1, 2, 4]
# Learning rates spaced by 0.75 in log space, from 1e-4 up to (not including) 0.015.
lr = [np.exp(log_lr) for log_lr in
        np.arange(np.log(0.0001), np.log(0.015), 0.75)]
batch_size = [32, 64, 128, 256]
d_state = [8,16, 32, 64]
d_conv = [2, 4, 8]
expand = [2,4]

# The order of hyperparam_list fixes the column layout that BayesOptim
# relies on when it unpacks a combination for each model type.
if model_type == "lstm":
    hidden_size = [32, 64, 128, 256]
    hyperparam_list = [embed_dim, num_layers, lr, batch_size, hidden_size]
elif model_type == "transformer":
    ff_dim = [256, 512]
    num_heads = [4, 8]
    hyperparam_list = [embed_dim, num_layers, lr, batch_size, ff_dim, num_heads]
elif model_type == "lru":
    r_min = 0.9
    r_max = 0.999
    hidden_size = [32, 64, 128, 256]
    hyperparam_list = [embed_dim, num_layers, lr, batch_size, hidden_size]
elif model_type == "mamba":
    hyperparam_list = [embed_dim, lr, batch_size, d_state, d_conv, expand]

# Cartesian product of all value lists, with the first hyperparameter varying
# fastest. Building new lists avoids deep-copying the whole set per value.
all_combos = [[]]
for hyperparam in hyperparam_list:
    all_combos = [combo + [h] for h in hyperparam for combo in all_combos]
hyperparams = torch.tensor(all_combos)
maxes = [max(a) for a in hyperparam_list]
print("combos done")

# ---- Checkpoint naming ----
model_name = model_type
if model_name in ["lstm", "lru"]:
    model_name += "_attention_" + str(not no_attention)
elif model_name == "transformer":
    model_name += "_pos_enc_" + str(pos_enc)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optim_path = "optim/"+model_name+"_"+filename+".optim"
# Create the checkpoint directory up front so saving cannot fail on it.
os.makedirs(os.path.dirname(optim_path), exist_ok=True)


def _save_optim(state, path):
    """Pickle the optimiser state to `path`, closing the file afterwards."""
    with open(path, "wb") as checkpoint:
        pickle.dump(state, checkpoint)


# ---- Resume or start the search ----
if os.path.isfile(optim_path):
    # NOTE: unpickling executes code from the file; only resume from
    # checkpoints this script wrote itself.
    with open(optim_path, "rb") as checkpoint_file:
        optim = pickle.load(checkpoint_file)
else:
    optim = BayesOptim(hyperparams, maxes, device, no_attention,
        pos_enc, data_file = "data/"+filename+".data",
        model_type=model_type)
_save_optim(optim, optim_path)
print("optim created")
# Seed the search with random trials until three results exist, then run ten
# acquisition-guided trials, checkpointing after every trial.
while len(optim.results) < 3:
    optim.initialise()
    _save_optim(optim, optim_path)
for _ in range(10):
    optim.optimise()
    _save_optim(optim, optim_path)