import os
from random import randint
import uuid

from quinine import QuinineArgumentParser
from tqdm import tqdm
import torch
import yaml

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from schema import schema
from models import build_model

import wandb

torch.backends.cudnn.benchmark = True


def train_step(model, xs, ys, optimizer, loss_func):
    """Perform a single optimisation step and report its outcome.

    Args:
        model: Callable invoked as ``model(xs, ys)`` to produce predictions.
        xs: Input batch passed to ``model``.
        ys: Target batch passed to ``model`` and to ``loss_func``.
        optimizer: Optimizer whose gradients are cleared before the backward
            pass and which is stepped afterwards.
        loss_func: Callable invoked as ``loss_func(predictions, ys)`` that
            returns a scalar tensor supporting ``backward()``.

    Returns:
        A ``(loss_value, predictions)`` tuple: the loss as a Python number and
        the model output detached from the autograd graph.
    """
    # Clear gradients left over from the previous step before accumulating.
    optimizer.zero_grad()

    predictions = model(xs, ys)
    batch_loss = loss_func(predictions, ys)

    batch_loss.backward()
    optimizer.step()

    loss_value = batch_loss.detach().item()
    return loss_value, predictions.detach()


def sample_seeds(total_seeds, count):
    """Draw ``count`` distinct integers uniformly from ``[0, total_seeds)``.

    Args:
        total_seeds: Size of the seed pool; candidates are
            ``0 .. total_seeds - 1``.
        count: Number of distinct seeds to return.

    Returns:
        A ``set`` containing ``count`` distinct integers (empty when
        ``count <= 0``).

    Raises:
        ValueError: If ``count`` exceeds ``total_seeds``. The pool cannot
            supply that many distinct values, so rejection sampling would
            otherwise never terminate.
    """
    if count > total_seeds:
        raise ValueError(
            f"cannot draw {count} distinct seeds from a pool of {total_seeds}"
        )

    seeds = set()
    # Rejection sampling: duplicates are absorbed by the set, so keep drawing
    # until enough distinct values have been collected.
    while len(seeds) < count:
        seeds.add(randint(0, total_seeds - 1))
    return seeds


def train(model, args):
    """Run the training loop, resuming from a saved state when one exists.

    The loop samples inputs and a task each step, takes one optimiser step,
    periodically logs to wandb, and periodically writes checkpoints into
    ``args.out_dir``.

    Args:
        model: Model to optimise. Must expose ``n_dims``, ``parameters()``,
            ``state_dict()`` / ``load_state_dict()`` and be callable as
            ``model(xs, ys)``. Inputs are moved with ``.cuda()``, so the model
            is expected to live on a CUDA device already.
        args: Configuration object. Fields read here:
            ``out_dir``, ``test_run``, ``wandb.log_every_steps`` and, under
            ``training``: ``learning_rate``, ``curriculum``, ``batch_size``,
            ``data``, ``task``, ``num_tasks``, ``task_kwargs``,
            ``train_steps``, ``num_training_examples``, ``save_every_steps``,
            ``keep_every_steps``.

    Side effects:
        Writes ``state.pt`` (resumable state) and ``model_<step>.pt``
        (kept snapshots) under ``args.out_dir`` and logs metrics to wandb,
        all of which are skipped when ``args.test_run`` is true.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        # NOTE(review): the loop below restarts AT the saved step, so that
        # step is trained once more after a resume, while the curriculum is
        # replayed one update further. Looks like an off-by-one; confirm
        # before changing, since it shifts logged/saved step numbers.
        starting_step = state["train_step"]
        # Replay the curriculum so it matches the state it had when the
        # checkpoint was written (one update per completed step).
        for _ in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))

    num_training_examples = args.training.num_training_examples
    # Both checks are loop-invariant, so evaluate them once up front.
    if num_training_examples is not None:
        assert num_training_examples >= bsize, (
            "num_training_examples must be at least batch_size"
        )
    is_sparse_task = "sparse" in args.training.task

    for i in pbar:
        data_sampler_args = {}
        task_sampler_args = {}

        if is_sparse_task:
            task_sampler_args["valid_coords"] = curriculum.n_dims_truncated
        if num_training_examples is not None:
            # Restrict training to a fixed pool of examples by drawing the
            # per-example seeds from that pool; task seeds are offset by one
            # from the data seeds.
            seeds = sample_seeds(num_training_examples, bsize)
            data_sampler_args["seeds"] = seeds
            task_sampler_args["seeds"] = [s + 1 for s in seeds]

        xs = data_sampler.sample_xs(
            curriculum.n_points,
            bsize,
            curriculum.n_dims_truncated,
            **data_sampler_args,
        )
        task = task_sampler(**task_sampler_args)
        ys = task.evaluate(xs)

        loss_func = task.get_training_metric()

        # Move the batch to the GPU once and reuse it for training and logging.
        xs_gpu, ys_gpu = xs.cuda(), ys.cuda()
        loss, output = train_step(model, xs_gpu, ys_gpu, optimizer, loss_func)

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            # The per-point metrics are only needed for logging, so they are
            # computed on logging steps only. They use the curriculum values
            # of the step just trained (curriculum.update() runs afterwards).
            n_points = curriculum.n_points
            n_dims_truncated = curriculum.n_dims_truncated

            point_wise_loss_func = task.get_metric()
            point_wise_loss = point_wise_loss_func(output, ys_gpu).mean(dim=0)

            # Reference value used to normalise the loss into "excess_loss".
            baseline_loss = (
                sum(max(n_dims_truncated - ii, 0) for ii in range(n_points))
                / n_points
            )

            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(range(n_points), point_wise_loss.cpu().numpy())
                    ),
                    "n_points": n_points,
                    "n_dims": n_dims_truncated,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        if i % args.training.save_every_steps == 0 and not args.test_run:
            # Rolling checkpoint used for resuming; overwritten each time.
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            # Permanent snapshot of the weights, one file per kept step.
            torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt"))


def main(args):
    """Prepare the run described by ``args`` and start training.

    For a regular run a wandb session is initialised (with ``resume=True``).
    For a test run wandb is skipped, the curriculum is pinned to its final
    values and the run is shortened to 100 steps. In both cases the model is
    built, moved to the GPU, put in training mode and handed to ``train``.

    Args:
        args: Configuration object providing ``test_run``, ``out_dir``,
            ``model``, ``training`` and (for regular runs) ``wandb``.
    """
    if not args.test_run:
        wandb.init(
            dir=args.out_dir,
            project=args.wandb.project,
            entity=args.wandb.entity,
            config=args.__dict__,
            notes=args.wandb.notes,
            name=args.wandb.name,
            resume=True,
        )
    else:
        # Skip the ramp-up: start points/dims at their end values.
        schedule = args.training.curriculum
        schedule.points.start = schedule.points.end
        schedule.dims.start = schedule.dims.end
        args.training.train_steps = 100

    model = build_model(args.model)
    model.cuda()
    model.train()

    train(model, args)

class MyModelArgs:
    """Hand-written model configuration passed to ``build_model``.

    Every field is a keyword-only constructor argument whose default is the
    value this script has always used, so ``MyModelArgs()`` is unchanged while
    individual settings can now be overridden per experiment.

    Attributes:
        family: Model family name.
        n_dims: Input dimensionality.
        n_embd: Embedding width.
        n_head: Number of attention heads.
        n_layer: Number of layers.
        n_positions: Number of positions.
    """

    def __init__(
        self,
        *,
        family="gpt2",
        n_dims=5,
        n_embd=256,
        n_head=8,
        n_layer=12,
        n_positions=11,
    ):
        self.family = family
        self.n_dims = n_dims
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_positions = n_positions

class MyTrainArgs:
    """Hand-written training configuration used by the script in this file.

    Every field is a keyword-only constructor argument whose default is the
    value this script has always used, so ``MyTrainArgs()`` is unchanged while
    individual settings can now be overridden per experiment.

    Args:
        task_kwargs: Extra keyword arguments for the task sampler. ``None``
            (the default) gives each instance its own fresh empty dict.
    """

    def __init__(
        self,
        *,
        batch_size=64,
        curriculum=None,
        data="gaussian",
        keep_every_steps=100000,
        learning_rate=0.0001,
        num_tasks=None,
        num_training_examples=None,
        resume_id=None,
        save_every_steps=1000,
        task="linear_regression",
        task_kwargs=None,
        train_steps=5001,
    ):
        self.batch_size = batch_size
        self.curriculum = curriculum  # ignore curriculum for now
        self.data = data
        self.keep_every_steps = keep_every_steps
        self.learning_rate = learning_rate
        self.num_tasks = num_tasks
        self.num_training_examples = num_training_examples
        self.resume_id = resume_id
        self.save_every_steps = save_every_steps
        self.task = task
        # Never share one dict between instances.
        self.task_kwargs = {} if task_kwargs is None else task_kwargs
        self.train_steps = train_steps

class MyArgs:
    """Hand-written top-level configuration used by the script in this file.

    Every field is a keyword-only constructor argument whose default is the
    value this script has always used, so ``MyArgs()`` is unchanged while
    individual settings can now be overridden per experiment.

    Args:
        config: Path of the configuration file this setup corresponds to.
        model: Model configuration; a default ``MyModelArgs()`` when ``None``.
        out_dir: Output directory for the run.
        test_run: Whether this is a test run.
        training: Training configuration; a default ``MyTrainArgs()`` when
            ``None``.
        wandb: wandb configuration; left as ``None`` by default.
    """

    def __init__(
        self,
        *,
        config="conf/toy.yaml",
        model=None,
        out_dir="../models/linear_regression",
        test_run=False,
        training=None,
        wandb=None,
    ):
        self.config = config
        self.model = MyModelArgs() if model is None else model

        self.out_dir = out_dir
        self.test_run = test_run
        self.training = MyTrainArgs() if training is None else training

        self.wandb = wandb


# ---------------------------------------------------------------------------
# Stand-alone experiment. Everything from here on runs at import time and
# does not go through main()/train().
# ---------------------------------------------------------------------------
args = MyArgs()

# Build the model, move it to the GPU and put it in training mode.
model = build_model(args.model)
model.cuda()
model.train()

starting_step = 0

n_dims, bsize = model.n_dims, args.training.batch_size

data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
task_sampler = get_task_sampler(
    args.training.task,
    n_dims,
    bsize,
    num_tasks=args.training.num_tasks,
    **args.training.task_kwargs,
)
pbar = tqdm(range(starting_step, args.training.train_steps))

i = 0
data_sampler_args, task_sampler_args = {}, {}

# Fixed stand-ins for the values a Curriculum would normally provide.
curriculum_n_points, curriculum_n_dims_truncated = 11, 5

# Draw one batch of inputs.
xs = data_sampler.sample_xs(
    curriculum_n_points,
    bsize,
    curriculum_n_dims_truncated,
    **data_sampler_args,
)

# The sampled task is a little awkward to work with because it carries many
# w_b's rather than a single one.
task = task_sampler(**task_sampler_args)
ys = task.evaluate(xs)

# Start small: a single prompt (the first one in the batch) and one w_b.
small_x = xs[:1].cuda()

# Use the weight the sampled task itself holds for that first prompt.
w_b = task.w_b[:1].cuda()

# y = w_b^T x
# Intended to reproduce ys[:1] (TODO confirm shapes agree exactly).
small_y = small_x @ w_b

loss_func = task.get_training_metric()

# Freeze the model; only the prompt tensors are optimised from here on.
for param in model.parameters():
    param.requires_grad_(False)

small_x.requires_grad = True
small_y.requires_grad = True

# NOTE(review): lr=1e6 is far larger than typical Adam settings -- confirm
# that this is intentional.
optimizer = torch.optim.Adam([small_x, small_y], lr=1e6)
optimizer.zero_grad()

last_idx = small_y.shape[1] - 1
pred = model(small_x, small_y, inds=[last_idx])

# we only care about the loss for the last (x, y) pair
loss = torch.norm(small_y[0][last_idx] - pred)**2

loss.backward()
print(f"Iter -1: Loss={loss.item()}")
print(small_x.grad)
print(small_y.grad)

# The final (x, y) pair is the query and must stay fixed, so its gradient is
# zeroed in place before the optimiser sees it.
small_x.grad[0, -1].zero_()
small_y.grad[0, -1].zero_()

optimizer.step()

# Build a small dataset in which only the final (query) input/output pair
# changes; the rest of the prompt is shared.
my_dataset = []
train_dataset_size = 5
for _ in range(train_dataset_size):
    x_query = torch.randn_like(small_x[0][-1])
    y_query = x_query @ w_b
    my_dataset.append((x_query, y_query))

# First check the average loss on this dataset before optimising further.
# This pass is evaluation only, so no autograd graph is recorded.
total_loss = 0
last_idx = small_y.shape[1] - 1
with torch.no_grad():
    for x_query, y_query in my_dataset:
        # Swap the query pair into the last slot of the prompt.
        small_x.data[0][-1] = x_query
        small_y.data[0][-1] = y_query

        pred = model(small_x, small_y, inds=[last_idx])

        # we only care about the loss for the last (x, y) pair
        loss = torch.norm(small_y[0][last_idx] - pred)**2
        total_loss += loss.item()
print(f"Avg Train loss to start with: {total_loss / train_dataset_size}")


# Optimise the shared part of the prompt over the query dataset.
NUM_EPOCHS = 50
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    total_grad_norm = 0
    for x_query, y_query in my_dataset:
        optimizer.zero_grad()
        # Swap the query pair into the last slot of the prompt.
        small_x.data[0][-1] = x_query
        small_y.data[0][-1] = y_query

        last_idx = small_y.shape[1] - 1
        pred = model(small_x, small_y, inds=[last_idx])

        # we only care about the loss for the last (x, y) pair
        loss = torch.norm(small_y[0][last_idx] - pred)**2

        loss.backward()

        # The final (x, y) pair is the query and must stay fixed, so its
        # gradient is zeroed in place before the optimiser step.
        small_x.grad[0, -1].zero_()
        small_y.grad[0, -1].zero_()

        optimizer.step()
        total_loss += loss.item()
        # Accumulate a plain float (measured after masking) so the summary
        # prints a number rather than a device tensor repr.
        total_grad_norm += (
            torch.norm(small_y.grad) + torch.norm(small_x.grad)
        ).item()
    print(f"Average loss for epoch {epoch}: {total_loss / train_dataset_size}")
    print(
        f"Average grad norm for epoch {epoch}: "
        f"{total_grad_norm / train_dataset_size}"
    )

