import os
from random import randint
import uuid
import math

from quinine import QuinineArgumentParser
from quinine import Quinfig
from tqdm import tqdm
import torch
import torch.nn as nn
import yaml

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from schema import schema
from models import build_model
import numpy as np

import wandb

# Let cuDNN benchmark and cache the fastest kernels; batch and sequence shapes
# stay fixed for the whole run.
torch.backends.cudnn.benchmark = True


# Experiment constants. Most are copied into ``args`` below, overriding the
# values loaded from the YAML config.
INPUT_DIM = 4  # dimension of every input row (also args.model.n_dims)
BATCH_SIZE = 64
# NOTE(review): NUM_LAYERS is never written into ``args.model`` below, so the
# layer count comes from the YAML config -- confirm that is intended.
NUM_LAYERS = 12
EMBED_DIM = 256
# NOTE(review): NUM_FUNCTIONS is not referenced elsewhere in this script; the
# number of skills is fixed by the hard-coded ``skill_set`` list.
NUM_FUNCTIONS = 2
SEQ_LEN = 20  # rows per sequence (also args.model.n_positions)
# dimension of every target row; change this to 1 for 1-dim testing
OUTPUT_DIM = 4

# Load the base linear-regression config, then patch it in place.
args = Quinfig(config_path='conf/linear_regression.yaml')
# Override config fields so they agree with the constants above.
args.model.n_dims = INPUT_DIM
args.model.zero_pad_embed = False
args.training.num_tasks = None
args.n_dims = INPUT_DIM
args.training.batch_size = BATCH_SIZE
# start == end on both curriculum axes, presumably so the curriculum never
# changes dims/points during training -- confirm against curriculum.py.
args.training.curriculum.dims.start = INPUT_DIM
args.training.curriculum.dims.end = INPUT_DIM
args.training.curriculum.points.start = SEQ_LEN
args.training.curriculum.points.end = SEQ_LEN
args.model.n_positions = SEQ_LEN
# Plain values passed to ``data_sampler.sample_xs`` as the number of points
# and the (truncated) dimension of each sampled batch.
curriculum_n_points = SEQ_LEN
curriculum_n_dims_truncated = INPUT_DIM
args.model.n_embd = EMBED_DIM
args.training.learning_rate = 0.001
args.model.hidden_layer_size = 0
args.model.apply_cot = False

# Build the model from the patched config; everything below assumes CUDA.
model = build_model(args.model)
model.cuda()
model.train()

n_dims = model.n_dims
bsize = args.training.batch_size
data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
task_sampler = get_task_sampler(
    args.training.task,
    n_dims,
    bsize,
    num_tasks=args.training.num_tasks,
    **args.training.task_kwargs,
)
# No tqdm bar is created: the training loop runs a fixed number of steps and
# reports progress with its own prints, so a bar sized by
# ``args.training.train_steps`` would be both idle and misleading.

# Extra keyword arguments forwarded to the samplers (none needed here).
data_sampler_args = {}
task_sampler_args = {}

# Initial batch of sampled x rows.
xs = data_sampler.sample_xs(
    curriculum_n_points,
    bsize,
    curriculum_n_dims_truncated,
    **data_sampler_args,
)

# NOTE(review): ``task`` is only needed by the commented-out
# ``task.get_training_metric()`` loss in the training loop.
task = task_sampler(**task_sampler_args)

# Skill matrices indexed by skill id. The data only uses ids 1 and 2; index 0
# is a placeholder so an id can index the list directly. The matrices are
# drawn fresh on every run.
# skill_set = [None, torch.Tensor((1, 1, 1, 1)), torch.Tensor((1, -1, 1, -1))]
skill_set = [None, torch.randn(INPUT_DIM, OUTPUT_DIM), torch.randn(INPUT_DIM, OUTPUT_DIM)]

# Zero canvas, shaped like ``xs``, into which x rows and skill rows are written.
zs = torch.zeros_like(xs)

# NOTE: this function is not guaranteed to maximize the number of skills in a
# sequence, which is probably a good thing (sequences vary in how many pairs they hold).
def get_xidxs(SEQ_LEN):
    """Pick random positions for "x" rows in a sequence of length ``SEQ_LEN``.

    Each chosen position ``i`` holds an x vector, and ``i + 1`` holds the skill
    applied to it. Positions are therefore drawn from ``[0, SEQ_LEN - 2]``, and
    no two chosen positions may be equal or adjacent, which would make an x row
    and a skill row collide. ``SEQ_LEN // 2`` candidates are drawn with
    replacement and filtered greedily in draw order, so the result usually
    holds fewer than ``SEQ_LEN // 2`` positions.

    Args:
        SEQ_LEN: sequence length.

    Returns:
        Sorted 1-D int64 numpy array of x positions. It is empty when
        ``SEQ_LEN < 2``, since no (x, skill) pair fits.
    """
    if SEQ_LEN < 2:
        # Return an integer (not float64) empty array so it can still index tensors.
        return np.array([], dtype=np.int64)

    # At most SEQ_LEN // 2 (x, skill) pairs fit in the sequence.
    candidates = np.random.choice(SEQ_LEN - 1, SEQ_LEN // 2)

    taken = set()
    for idx in candidates:
        # Skip duplicates and neighbours of already-chosen x positions.
        if idx in taken or idx - 1 in taken or idx + 1 in taken:
            continue
        taken.add(idx)
    return np.sort(np.array(list(taken), dtype=np.int64))

# Training schedule.
TRAIN_STEPS = 200000
LOG_EVERY = 100
CHECKPOINT_EVERY = 100000


def _build_skill_batch(xs):
    """Build one (zs, ys) training pair from sampled ``xs``.

    For every sequence, ``get_xidxs`` picks positions that receive the
    matching rows of ``xs``. The row right after each one becomes a skill row:
    its first coordinate is a random skill id in {1, 2} and the rest are zero.
    All remaining rows are zero.

    The target at a skill row is ``x @ skill_set[id]`` for the x row just
    before it; every other target row is zero.

    Args:
        xs: tensor indexed as ``xs[batch, position, :]``, as returned by
            ``data_sampler.sample_xs``.

    Returns:
        ``(zs, ys)``: ``zs`` has the shape of ``xs``; ``ys`` has shape
        ``(xs.shape[0], xs.shape[1], OUTPUT_DIM)``.
    """
    zs = torch.zeros_like(xs)
    for batch_idx in range(zs.shape[0]):
        x_idxs = get_xidxs(SEQ_LEN)
        zs[batch_idx, x_idxs, :] = xs[batch_idx, x_idxs, :]
        skill_idxs = x_idxs + 1
        zs[batch_idx, skill_idxs, :] = torch.randint_like(zs[batch_idx, skill_idxs, :], low=1, high=3)
        # Only the first coordinate carries the skill id.
        zs[batch_idx, skill_idxs, 1:] = 0

    ys = torch.zeros(zs.shape[0], zs.shape[1], OUTPUT_DIM)
    for skill_idx in [1, 2]:
        skill = skill_set[skill_idx].view(INPUT_DIM, OUTPUT_DIM)
        skill_locs = torch.where(zs[:, :, 0] == skill_idx, 1, 0).unsqueeze(dim=-1)
        # x_locs marks the rows whose *next* row carries this skill id.
        x_locs = torch.roll(skill_locs, -1, dims=1)
        ys += (x_locs * zs) @ skill
    # Move each target from its x row onto the skill row that follows it.
    # Skill rows are always x positions + 1, so position 0 is never a skill row
    # and the wrap-around in both rolls only moves zeros.
    ys = torch.roll(ys, 1, dims=1)
    return zs, ys


optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
# loss_func = task.get_training_metric()
# Kept as a module-level name; later evaluation code in this file may refer to it.
loss_func = nn.MSELoss()

# The skill matrices are random per run; save them so the checkpoints written
# below can be evaluated against the skills they were trained on.
torch.save(skill_set, "skill_cot_skill_set.pt")

for step in range(TRAIN_STEPS):
    xs = data_sampler.sample_xs(
        curriculum_n_points,
        bsize,
        curriculum_n_dims_truncated,
        **data_sampler_args,
    )
    zs, ys = _build_skill_batch(xs)

    optimizer.zero_grad()
    output = model(zs.cuda(), None)
    loss = loss_func(output, ys.cuda())
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        print("Step={} | Loss = {}".format(step, loss.item()))

    if step % CHECKPOINT_EVERY == 0:
        torch.save(model.state_dict(), "skill_cot_model_step_{}.pt".format(step))


torch.save(model.state_dict(), "skill_cot_model.pt")


# evaluate model
# test_x = xs[0].clone()
# test_y = torch.zeros_like(test_x)
# test_x[1::2, :] = torch.randint_like(test_x[1::2, :], low=1, high=3)
# test_x[1::2, 1:] = torch.zeros_like(test_x[1::2, 1:])

# output = model(test_x.cuda(), None)
# predictions = output.clone()
# predictions[::2, :] = torch.zeros_like(predictions[::2, :])

def tensor_to_str(t):
    """Return every element of ``t`` (flattened) formatted to 3 decimals, space-separated."""
    return " ".join(map("{:.3f}".format, t.flatten()))


def evaluate_model(model, TEST_BATCH_SIZE=1, print_predictions=True):
    """Evaluate ``model`` on a freshly sampled batch and report per-skill MSE.

    The batch is built the same way as in the training loop. ``get_xidxs``
    picks positions that receive sampled x rows. The row right after each one
    is a skill row: its first coordinate is a random skill id in {1, 2} and
    the rest are zero. Every other row is zero. The target at a skill row is
    ``x @ skill_set[id]`` for the x row just before it; all other targets are
    zero.

    The model runs in eval mode without gradient tracking, and its previous
    train/eval mode is restored afterwards. Requires CUDA, because inputs are
    moved with ``.cuda()``.

    Args:
        model: module called as ``model(zs, None)``. Its output is compared
            with targets of shape (TEST_BATCH_SIZE, positions, OUTPUT_DIM).
        TEST_BATCH_SIZE: number of sequences to sample.
        print_predictions: if True, print the inputs and, for every row, the x
            shown, the skill id, the prediction and the target.

    Returns:
        List of three floats indexed by skill id (0, 1, 2). Entry ``k`` is
        ``nn.MSELoss`` between predictions and targets after zeroing every row
        whose first coordinate is not ``k``. The mean runs over all elements,
        so each value is scaled down by the fraction of rows carrying that id.
        Id 0 selects rows whose first coordinate is exactly 0, which for
        sampled x values presumably means only the zero padding rows.
    """
    # Local criterion: do not depend on a ``loss_func`` global that only exists
    # once the training loop has run.
    loss_func = nn.MSELoss()

    xs = data_sampler.sample_xs(
        curriculum_n_points,
        TEST_BATCH_SIZE,
        curriculum_n_dims_truncated,
        **data_sampler_args,
    )

    # Place x rows and the skill rows that follow them.
    zs = torch.zeros_like(xs)
    for batch_idx in range(TEST_BATCH_SIZE):
        x_idxs = get_xidxs(SEQ_LEN)
        zs[batch_idx, x_idxs, :] = xs[batch_idx, x_idxs, :]
        skill_idxs = x_idxs + 1
        zs[batch_idx, skill_idxs, :] = torch.randint_like(zs[batch_idx, skill_idxs, :], low=1, high=3)
        zs[batch_idx, skill_idxs, 1:] = torch.zeros_like(zs[batch_idx, skill_idxs, 1:])

    # Targets need OUTPUT_DIM channels, as in training. A trailing dim of 1
    # cannot receive the (..., OUTPUT_DIM) product below.
    ys = torch.zeros(zs.shape[0], zs.shape[1], OUTPUT_DIM)
    for skill_idx in [1, 2]:
        skill = skill_set[skill_idx]
        skill_locs = torch.where(zs[:, :, 0] == skill_idx, 1, 0).unsqueeze(dim=-1)
        # x_locs marks the rows whose *next* row carries this skill id.
        x_locs = torch.roll(skill_locs, -1, dims=1)
        ys += (x_locs * zs) @ skill.view(INPUT_DIM, OUTPUT_DIM)
    # Move each target from its x row onto the skill row that follows it.
    ys = torch.roll(ys, 1, dims=1)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(zs.cuda(), None)
    finally:
        model.train(was_training)

    if print_predictions:
        print("Input: {}".format(zs))
        predictions_cpu = predictions.cpu()
        for batch_idx in range(TEST_BATCH_SIZE):
            for idx in range(zs.shape[1]):
                row = zs[batch_idx, idx]
                if torch.norm(row[1:]) > 0:
                    # an "x" row
                    shown_x, skill_idx = row, 0
                elif torch.norm(row) == 0:
                    # an all-zero padding row
                    shown_x, skill_idx = row, 0
                else:
                    # a skill row: show the x it applies to
                    shown_x, skill_idx = zs[batch_idx, idx - 1], row[0]
                x_str = "[" + "\t".join(["{:.2f}".format(a) for a in shown_x]) + "]"
                # Prediction/target rows have OUTPUT_DIM entries, so render them
                # with tensor_to_str instead of .item(), which needs one element.
                print("Input: {} \t\t| Skill: {:.3f} \t | Prediction: {} \t| Actual y: {}".format(
                    x_str,
                    skill_idx,
                    tensor_to_str(predictions_cpu[batch_idx, idx]),
                    tensor_to_str(ys[batch_idx, idx]),
                ))
            print("\n\n")

    # Compute the (masked) error for each skill id.
    ys_gpu = ys.cuda()
    skill_loss_list = [0, 0, 0]
    for skill_idx in [0, 1, 2]:
        skill_locs = torch.where(zs[:, :, 0] == skill_idx, 1, 0).unsqueeze(dim=-1).cuda()
        skill_loss = loss_func(predictions * skill_locs, ys_gpu * skill_locs)
        skill_loss_list[skill_idx] = skill_loss.item()
        print("MSE for skill {} = {}".format(skill_idx, skill_loss.item()))

    return skill_loss_list


# @deprecated
# Left over from an earlier setup where skills were mod-2 patterns and the
# (x, skill) pairs had to be shuffled. Not called anywhere in this script.
def evaluate_shuffled_model(model):
    """Deprecated: evaluate on an alternating x/skill layout, then probe position shifts.

    Builds one sequence from the first sample of a fresh batch. Even rows keep
    their sampled x values; odd rows become skill rows (first coordinate a
    random id in {1, 2}, remaining coordinates zero). It prints per-row
    predictions and per-skill MSE.

    It then places the first (x, skill) pair at rows ``shuffle_idx`` and
    ``shuffle_idx + 1`` of an otherwise-zero sequence, for ``shuffle_idx`` in
    0..9. For each placement it prints the error at the moved skill row next
    to the error at row 1 of the unshuffled sequence.

    NOTE(review): this looks out of date with the rest of the file and likely
    fails if called:
      * ``skill.view(4, 1)`` cannot reshape a ``skill_set`` entry, which is
        created with shape (INPUT_DIM, OUTPUT_DIM) = (4, 4);
      * ``model`` is given a 2-D ``test_x`` (no batch dim), unlike the 3-D
        ``zs`` used in training -- confirm the model accepts that;
      * ``loss_func`` is a module-level name that only exists after the
        training loop has run.

    Returns:
        List of three values. Index 0 stays 0; indices 1 and 2 hold the masked
        MSE over skill rows for those skill ids.
    """
    xs = data_sampler.sample_xs(
        curriculum_n_points,
        bsize,
        curriculum_n_dims_truncated,
        **data_sampler_args,
    )

    # Alternating layout: even rows are x rows, odd rows are skill-id rows.
    test_x = xs[0].clone()
    test_x[1::2, :] = torch.randint_like(test_x[1::2, :], low=1, high=3)
    test_x[1::2, 1:] = torch.zeros_like(test_x[1::2, 1:])

    output = model(test_x.cuda(), None)
    predictions = output.clone()
    # Zero predictions at x rows so only skill-row predictions are reported.
    predictions[::2, :] = torch.zeros_like(predictions[::2, :])

    # now construct y which is the target
    test_y = torch.zeros(test_x.shape[0], 1)

    # first handle 1s, then 2s: the target at each odd row is the preceding
    # even row's x multiplied by that row's skill matrix
    for skill_idx in [1, 2]:
        skill = skill_set[skill_idx]
        skill_locs = torch.where(test_x[1::2, 0] == skill_idx, 1, 0).unsqueeze(dim=-1)
        test_y[1::2, :] += (skill_locs * test_x[0::2, :]) @ (skill.view(4, 1))

    print("Input: {}".format(test_x))
    for idx, pred in enumerate(predictions.tolist()):
        if idx % 2 == 0:
            x = "[" + "\t".join(["{:.2f}".format(a) for a in test_x[idx]]) + "]"
            skill = 0
        else:
            # skill row: display the x it applies to
            x = "[" + "\t".join(["{:.2f}".format(a) for a in test_x[idx-1]]) + "]"
            skill = test_x.tolist()[idx][0]
        y = test_y.tolist()[idx][0]
        print("Input: {} \t\t| Skill: {:.3f} \t | Prediction: {:.3f} \t| Actual y: {:.3f}".format(x, skill, pred[0], y))

    print("\n\n")
    # compute error by skills (masked MSE over the odd/skill rows only)
    skill_loss_list = [0, 0, 0]
    for skill_idx in [1, 2]:
        skill = skill_set[skill_idx]
        skill_locs = torch.where(test_x[1::2, 0] == skill_idx, 1, 0).unsqueeze(dim=-1).cuda()
        skill_loss = loss_func((test_y[1::2, :].cuda()*skill_locs), (predictions[1::2, :]*skill_locs))
        skill_loss_list[skill_idx] = skill_loss.item()
        print("MSE for skill {} = {}".format(skill_idx, skill_loss.item()))

    # NOTE(review): the "\\n" in this literal prints a backslash followed by
    # "n", not a newline -- probably a typo.
    print("\n\n Now shuffle the inputs \\n\n")
    # shuffle things: move the first (x, skill) pair to rows shuffle_idx and
    # shuffle_idx + 1 of an all-zero sequence and compare errors
    for shuffle_idx in range(10):
        shuffled_x = torch.zeros_like(test_x)
        # NOTE(review): shuffled_y is filled in but never read.
        shuffled_y = torch.zeros_like(test_y)
        shuffled_x[shuffle_idx] = test_x[0]
        shuffled_x[shuffle_idx+1] = test_x[1]
        shuffled_y[shuffle_idx+1] = test_y[1]
        # Same value on every iteration (error at row 1 of the unshuffled input).
        original_error = loss_func(test_y[1].cuda(), predictions[1].cuda())

        shuffled_output = model(shuffled_x.cuda(), None)
        pred = shuffled_output[shuffle_idx+1]
        shuffled_error = loss_func(pred.cuda(), test_y[1].cuda())
        # print(test_y[1], pred)

        print("Shuffling from 0->{} | Original error: {} | New error: {}".format(shuffle_idx, original_error, shuffled_error))
    return skill_loss_list

# Final evaluation of the trained model on 64 fresh sequences; per-row
# printing is skipped and the per-skill MSE list is kept for inspection.
eval_skill_losses = evaluate_model(
    model,
    TEST_BATCH_SIZE=64,
    print_predictions=False,
)
