import os
from random import randint
import uuid
import math

from quinine import QuinineArgumentParser
from quinine import Quinfig
from tqdm import tqdm
import torch
import torch.nn as nn
import yaml

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from schema import schema
from models import build_model
import numpy as np

import wandb

# Enable cuDNN benchmark mode for the whole process.
torch.backends.cudnn.benchmark = True


# Experiment-wide constants. Most of them are written into the Quinfig config
# further down and are also read directly by the sequence-building helpers.

# Width of each token vector (written to the config as n_dims).
INPUT_DIM = 5
# Number of prompts per batch.
BATCH_SIZE = 64
# NOTE(review): NUM_LAYERS is not referenced anywhere else in this file.
NUM_LAYERS = 12
# Written to args.model.n_embd.
EMBED_DIM = 256
# Number of skills; get_chain applies one matrix per skill.
NUM_SKILLS = 4
# Length (in tokens) of every prompt; also written to args.model.n_positions.
SEQ_LEN = 100
# change this to 1 for 1-dim testing
# NOTE(review): get_chain stacks the skill outputs together with xs and the
# skill-id tokens, so a value different from INPUT_DIM looks like it would
# fail there -- confirm before changing.
OUTPUT_DIM = 5
# Number of alternative matrices per skill.
NUM_SUBSKILLS = 4
# Smallest chain length; only used to size NUM_XS below.
MIN_C = 3
# Number of optimisation steps when training from scratch.
NUM_TRAINING_STEPS = 200000
# When True the training loop is skipped and "skill_cot_model.pt" is loaded.
LOAD_PRETRAINED_MODEL = True

# NOTE: choose the number of xs based on the maximum sequence length.
# The resulting in-context prompt has length (C+1)*num_xs, so we need
# num_xs <= SEQ_LEN/(C+1). Using the smallest C gives enough xs to fill
# SEQ_LEN for every chain length >= MIN_C.
NUM_XS = SEQ_LEN//(MIN_C+1)

# Load the base configuration from YAML.
args = Quinfig(config_path='conf/linear_regression.yaml')
# Override the loaded config in place with the constants defined above.
args.model.n_dims = INPUT_DIM
args.model.zero_pad_embed = False
args.training.num_tasks = None
args.n_dims = INPUT_DIM
args.training.batch_size = BATCH_SIZE
# Curriculum start == end, i.e. dims and points are held fixed.
args.training.curriculum.dims.start = INPUT_DIM
args.training.curriculum.dims.end = INPUT_DIM
args.training.curriculum.points.start = NUM_XS
args.training.curriculum.points.end = NUM_XS
args.model.n_positions = SEQ_LEN
# Plain variables used for every data_sampler.sample_xs call in this file.
curriculum_n_points = NUM_XS
curriculum_n_dims_truncated = INPUT_DIM
args.model.n_embd = EMBED_DIM
args.training.learning_rate = 0.001
args.model.hidden_layer_size = 0
args.model.apply_cot = False

# Build the model from the overridden config, move it to the GPU and put it
# in training mode.
model = build_model(args.model)
model.cuda()
model.train()

# NOTE(review): starting_step is only used to build pbar below.
starting_step = 0

n_dims = model.n_dims
bsize = args.training.batch_size
data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
task_sampler = get_task_sampler(
    args.training.task,
    n_dims,
    bsize,
    num_tasks=args.training.num_tasks,
    **args.training.task_kwargs,
)
# NOTE(review): pbar is created but never iterated in this file.
pbar = tqdm(range(starting_step, args.training.train_steps))

# NOTE(review): i is not read anywhere else in this file.
i = 0
# No extra keyword arguments are passed to either sampler.
data_sampler_args = {}
task_sampler_args = {}


# Draw one initial batch of inputs; it is passed to create_sequences once,
# before the optimizer is created.
xs = data_sampler.sample_xs(
    curriculum_n_points,
    bsize,
    curriculum_n_dims_truncated,
    **data_sampler_args,
)

# NOTE(review): task is not used anywhere else in this file.
task = task_sampler(**task_sampler_args)

def normalize_matrix(A):
    """Return ``A`` divided by ``torch.norm(A)``.

    With the default arguments ``torch.norm`` reduces over every element, so
    the result has unit norm. An all-zero input is not special-cased.
    """
    magnitude = torch.norm(A)
    return A / magnitude

# Skill table: skill_set[skill][subskill] is a random INPUT_DIM x OUTPUT_DIM
# matrix rescaled by normalize_matrix. Matrices are drawn skill by skill,
# subskill by subskill.
# NOTE(review): no seed is set in this file, so the table differs on every
# run -- including runs that load a saved checkpoint. Confirm this is intended.
skill_set = [
    [
        normalize_matrix(torch.randn(INPUT_DIM, OUTPUT_DIM))
        for _ in range(NUM_SUBSKILLS)
    ]
    for _ in range(NUM_SKILLS)
]

def one_hot(i=0, d=4):
    """Return a length-``d`` float tensor that is 1 at index ``i`` and 0 elsewhere."""
    basis = torch.zeros(d)
    basis[i] = 1
    return basis

# specifically designed for tensors of the shape [bsize, 1, dim]
def round_to_onehot(t):
    """Snap each single-token row of ``t`` to the one-hot vector of its argmax.

    Args:
        t: tensor of shape ``[bsize, 1, dim]``.

    Returns:
        A new tensor with the same shape, dtype and device as ``t`` holding
        1 at the position of the largest entry along the last axis of every
        row and 0 everywhere else.

    Raises:
        ValueError: if ``t`` carries more than one token along axis 1. The
            function used to print a message and return ``-1``, which a
            caller assigning the result into a tensor slice would silently
            broadcast.
    """
    # Unpacking also enforces that ``t`` is three-dimensional.
    _, num_tokens, _ = t.shape
    if num_tokens > 1:
        raise ValueError(
            "Error! Round to onehot not designed for multiple tokens "
            "(got {})".format(num_tokens)
        )
    onehot_vec = torch.zeros_like(t)
    # One vectorised write instead of a Python loop over the batch: the index
    # tensor has shape [bsize, num_tokens, 1] and picks the winning column.
    winners = torch.argmax(t, dim=2, keepdim=True)
    onehot_vec.scatter_(2, winners, 1.0)
    return onehot_vec

def get_chain(xs, C=3, subskill_ids=None):
    """Turn a batch of inputs into interleaved skill-chain prompts.

    For every input token the prompt contains the group
    ``x, skill_0, out_1, skill_1, out_2, ..., out_k, stop`` where
    ``out_{j+1} = out_j @ skill_set[j][subskill_ids[j]]`` (with ``out_0 = x``),
    ``skill_j`` is the one-hot id of skill ``j`` and ``stop`` is an all-zero
    token. The groups of consecutive input tokens are laid out one after the
    other along the sequence axis.

    Args:
        xs: tensor of shape ``[bsize, num_tokens, dim]``.
        C: chain length; ``(C+1)//2`` outputs (including ``xs``) and
            ``(C-1)//2`` skill ids are kept, followed by the stop token.
        subskill_ids: sequence with one subskill index per skill. Required
            despite the ``None`` default; it is indexed for every skill.

    Returns:
        Tensor of shape ``[bsize, L, dim]`` where ``L`` is the interleaved
        length capped at ``SEQ_LEN``. The last ``SEQ_LEN % (C+1)`` positions
        are zeroed so that no partial group remains at the end.
    """
    outputs = [xs]
    skill_ids = []
    bsize, num_tokens, dim = xs.shape
    for idx in range(NUM_SKILLS):
        subskill_id = subskill_ids[idx]
        skill_id = torch.zeros_like(xs)
        skill_id[:, :] = one_hot(idx, dim)
        skill_ids.append(skill_id)
        # Compose on top of the most recent link. Indexing with ``idx-1``
        # only hit the previous link for idx == 0 (through negative
        # indexing); from idx == 1 on it read the link two steps back, so the
        # skills were never chained.
        outputs.append(outputs[-1] @ skill_set[idx][subskill_id])

    # now trim the chain based on C and interleave
    trimmed_num_outputs = (C+1)//2
    trimmed_num_skills = (C-1)//2
    outputs = outputs[:trimmed_num_outputs]
    skill_ids = skill_ids[:trimmed_num_skills]
    # add stop_token = 0
    skill_ids.append(torch.zeros_like(xs))

    stacked_tensors = []
    for idx, output in enumerate(outputs):
        stacked_tensors.append(output)
        if idx < len(skill_ids):
            stacked_tensors.append(skill_ids[idx])

    # interleave the skills and xs: stacking on a new axis right after the
    # token axis and flattening the two puts each token's group contiguously
    zs = torch.stack(stacked_tensors, dim=2)
    zs = zs.view(bsize, len(stacked_tensors) * num_tokens, dim)

    # update dims
    bsize, num_tokens, dim = zs.shape

    # trim it down
    if num_tokens > SEQ_LEN:
        zs = zs[:, :SEQ_LEN, :]
    # delete the incomplete chains
    leftover = SEQ_LEN % (C+1)
    if leftover > 0:
        zs[:, -leftover:, :] = 0

    return zs

def create_sequences(xs, C=None):
    """Build model inputs and next-token targets from a batch of xs.

    Each of the ``BATCH_SIZE`` prompts gets its own random subskill choice
    per skill and, when ``C`` is not given, its own random chain length.

    Args:
        xs: batch of inputs; row ``b`` is passed to ``get_chain`` as
            ``xs[b:b+1]``.
        C: chain length used for every prompt. When ``None`` a length is
            drawn independently per prompt from ``3, 5, ..., 2*NUM_SKILLS+1``
            (if C=3 there is 1 function, if C=5 there are 2, ...).

    Returns:
        ``(zs, ys)``: ``zs`` has shape ``[BATCH_SIZE, SEQ_LEN, INPUT_DIM]``
        and ``ys`` is ``zs`` shifted one token to the left with the final
        position set to zero.
    """
    zs = torch.zeros(BATCH_SIZE, SEQ_LEN, INPUT_DIM)
    chain_length_choices = [(2*c)+1 for c in range(1, NUM_SKILLS+1)]

    for batch_idx in range(BATCH_SIZE):
        # Keep the sampled length in a local. Writing it back into ``C`` made
        # the ``C is None`` test false after the first prompt, so the whole
        # batch ended up sharing a single chain length.
        if C is None:
            chain_length = np.random.choice(chain_length_choices)
        else:
            chain_length = C
        subskill_ids = [np.random.randint(NUM_SUBSKILLS) for _ in range(NUM_SKILLS)]
        # do this batchwise so that C,subskill_ids is chosen independently
        # for each prompt
        zs[batch_idx:batch_idx+1, :, :] = get_chain(
            xs[batch_idx:batch_idx+1, :, :],
            C=chain_length,
            subskill_ids=subskill_ids,
        )

    # The target at position p is the input token at position p+1; roll wraps
    # the first token around to the end, so blank the last position.
    ys = torch.roll(zs, -1, dims=1)
    ys[:, -1] = 0
    return zs, ys

# Build one batch of sequences from the initial xs.
# NOTE(review): this zs/ys pair is overwritten inside the training loop and is
# not read when the checkpoint is loaded instead.
zs, ys = create_sequences(xs)

# now just train on this
optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
# Mean-squared error between the model output and the shifted sequence.
loss_func = nn.MSELoss()

if not LOAD_PRETRAINED_MODEL:
    # Train from scratch: a fresh batch of xs and sequences at every step.
    for step in range(NUM_TRAINING_STEPS):
        xs = data_sampler.sample_xs(
            curriculum_n_points,
            bsize,
            curriculum_n_dims_truncated,
            **data_sampler_args,
        )

        zs, ys = create_sequences(xs)

        optimizer.zero_grad()
        # The second positional argument of the model is always None here.
        output = model(zs.cuda(), None)
        loss = loss_func(ys.cuda(), output.cuda())
        loss.backward()
        optimizer.step()

        # Log the loss every 100 steps.
        if (step % 100 == 0):
            print("Step={} | Loss = {}".format(step, loss.item()))

        # Periodic checkpoint; this also fires at step 0.
        if (step % 100000 == 0):
            torch.save(model.state_dict(), "skill_cot_model_step_{}.pt".format(step))


    # Final weights, saved under the name the else-branch loads from.
    torch.save(model.state_dict(), "skill_cot_model.pt")
else:
    # Skip training and restore previously saved weights onto the GPU.
    ckpt = torch.load("skill_cot_model.pt", map_location=f"cuda")
    model.load_state_dict(ckpt)


# evaluate model
# TODO: repeat this for each num_ic_examples
# for now, doing only for the final test_example

def evaluate_model(model, chain_lengths=(3, 5, 7), print_debug=False):
    """Score the model on autoregressively completing the last full chain.

    For each chain length a fresh batch of prompts is generated. Starting at
    the first token of the last complete chain, the model is run
    ``chain_length`` times; after each run its prediction for the next token
    is written back into the prompt, so later predictions are conditioned on
    earlier ones. Predictions at skill-id positions are snapped to one-hot
    vectors; the output tokens and the final stop token are kept as
    predicted.

    The forward passes run under ``torch.no_grad()`` with the model in eval
    mode, so no autograd graph is built through the in-place prompt updates
    and train-only layers are inactive. The model's previous train/eval mode
    is restored before returning.

    Args:
        model: module called as ``model(inputs, None)``.
        chain_lengths: iterable of chain lengths ``C`` to evaluate.
        print_debug: when True, print the first prompt's predictions and
            targets for every chain length.

    Returns:
        Dict mapping each chain length to the ``loss_func`` value between the
        targets and the predictions over the evaluated tail of the sequence.
    """
    test_error = {}
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for chain_length in chain_lengths:
                test_xs = data_sampler.sample_xs(
                    curriculum_n_points,
                    BATCH_SIZE,
                    curriculum_n_dims_truncated,
                    **data_sampler_args,
                )

                test_zs, test_ys = create_sequences(test_xs, C=chain_length)

                # start of final chain, counted from the end of the sequence:
                # the zeroed leftover tokens plus one full group of C+1 tokens
                test_x_idx = (SEQ_LEN%(chain_length+1) + (chain_length+1))
                final_prediction = torch.zeros_like(test_zs[:, -test_x_idx:])

                for idx in range(chain_length):
                    test_output = model(test_zs.cuda(), None)
                    curr_idx = -(test_x_idx-idx)
                    next_token_pred = test_output[:, curr_idx:(curr_idx+1)]
                    # Even offsets predict a skill id, except the last one,
                    # which predicts the stop token and is left as-is.
                    if (idx % 2 == 0) and (idx != chain_length - 1):
                        next_token_pred = round_to_onehot(next_token_pred)

                    # Feed the prediction back in as the next input token.
                    test_zs[:, (curr_idx+1):(curr_idx+2)] = next_token_pred
                    final_prediction[:, idx:(idx+1)] = next_token_pred

                # compare final_prediction and test_ys
                if print_debug:
                    print("Chain length: {}".format(chain_length))
                    print("Final prediction: {}".format(final_prediction[0]))
                    print("Target output: {}".format(test_ys[0, -test_x_idx:]))

                test_error[chain_length] = loss_func(
                    test_ys[:, -test_x_idx:].cuda(), final_prediction.cuda()
                )
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    print(test_error)
    return test_error

# Evaluate the trained/loaded model once, printing per-chain debug output.
# NOTE(review): with NUM_SKILLS = 4, get_chain can emit at most 10 tokens per
# group (5 outputs + 4 skill ids + stop), while evaluate_model assumes groups
# of chain_length + 1 = 12 tokens for chain length 11 -- confirm that 11 is
# meant to be evaluated.
test_error_dict = evaluate_model(model, chain_lengths=[3, 5, 7, 11], print_debug=True)


