import os
import shutil
from random import randint
import uuid
import datetime

from quinine import QuinineArgumentParser
from tqdm import tqdm
import torch
import yaml

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from schema import schema
from models import build_model

from plot_utils import basic_plot, collect_results, relevant_model_names
from eval_utils import get_evaluation_df

import wandb
import pandas as pd

torch.backends.cudnn.benchmark = True


def warmup_step(model, xs, IDs, optimizer, loss_func):
    """Run one warmup optimisation step driven by the second loss term only.

    The model is called as ``model(xs, IDs, loss_func)`` and must return a
    ``(losses, loss)`` pair. During warmup only ``losses[1]`` is
    back-propagated; the combined ``loss`` is reported but not optimised.

    Returns:
        ``(per_term_values, total_value)``: a list with one Python scalar per
        loss term, and the combined loss as a Python scalar.
    """
    optimizer.zero_grad()
    per_term, total = model(xs, IDs, loss_func)
    # Warmup optimises a single auxiliary objective (index 1), not the total.
    warmup_objective = per_term[1]
    warmup_objective.backward()
    optimizer.step()
    per_term_values = list(map(lambda term: term.detach().item(), per_term))
    return per_term_values, total.detach().item()


def train_step(model, xs, ys, optimizer, loss_func, layer_activations=None):
    """Run one optimisation step on the model's combined loss.

    The model is called as
    ``model(xs, ys, loss_func, layer_activations=layer_activations)`` and must
    return a ``(losses, loss)`` pair; ``loss`` is back-propagated and the
    optimizer is stepped once.

    Returns:
        ``(per_term_values, total_value)``: a list with one Python scalar per
        loss term, and the combined loss as a Python scalar.
    """
    optimizer.zero_grad()
    # (An alternative `model.forward_circulant` entry point was tried here previously.)
    outputs = model(xs, ys, loss_func, layer_activations=layer_activations)
    term_losses, objective = outputs
    objective.backward()
    optimizer.step()
    return [term.detach().item() for term in term_losses], objective.detach().item()


def sample_seeds(total_seeds, count):
    """Draw ``count`` distinct integer seeds uniformly from ``[0, total_seeds)``.

    Uses rejection sampling with ``random.randint`` (keeps drawing until
    ``count`` distinct values have been collected), so results depend on the
    global ``random`` state.

    Args:
        total_seeds: size of the seed pool; seeds lie in ``range(total_seeds)``.
        count: number of distinct seeds to return.

    Returns:
        A ``set`` of ``count`` distinct ints.

    Raises:
        ValueError: if ``count`` exceeds ``total_seeds``. Without this check
            the sampling loop could never collect enough distinct values and
            would spin forever.
    """
    if count > total_seeds:
        raise ValueError(
            f"cannot draw {count} distinct seeds from a pool of {total_seeds}"
        )
    seeds = set()
    while len(seeds) < count:
        seeds.add(randint(0, total_seeds - 1))
    return seeds


def train(model, args):
    """Train ``model`` according to the parsed config ``args``.

    Builds an Adam optimizer and a ``Curriculum``. If ``<args.out_dir>/state.pt``
    exists, model and optimizer state are restored and the curriculum is
    fast-forwarded so training resumes. Each step samples inputs from the data
    sampler, obtains targets from the task, and takes one optimizer step:
    ``warmup_step`` for the first ``args.training.warmup_steps`` steps of
    ``cot_skill_chain``, ``train_step`` otherwise.

    Side effects (skipped when ``args.test_run``, except the final snapshot):
      * logs to wandb every ``args.wandb.log_every_steps`` steps;
      * overwrites ``state.pt`` every ``args.training.save_every_steps`` steps;
      * writes ``model_<step>.pt`` every ``args.training.keep_every_steps``
        steps (when > 0). It is also always written at the final step.

    Args:
        model: module exposing ``n_dims`` and ``n_intermediate_activations``
            and callable as expected by ``train_step`` / ``warmup_step``.
            Inputs are moved with ``.cuda()``, so the model is presumably
            already on the GPU (``main`` does this).
        args: parsed quinine config with ``training``, ``wandb``, ``out_dir``
            and ``test_run`` entries.

    Raises:
        NotImplementedError: if ``args.training.task`` is not one of
            ``relu_nn_regression``, ``long_chain`` or ``cot_skill_chain``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    # Resume from a checkpoint if one exists.
    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        # NOTE(review): the checkpoint is written after step `train_step` has
        # already run (and its curriculum.update() has happened). The loop below
        # restarts at that same step, so it is executed a second time on resume.
        starting_step = state["train_step"]
        # Replay curriculum updates for steps 0..train_step inclusive.
        for i in range(state["train_step"] + 1):
            curriculum.update()

    n_dims = model.n_dims
    bsize = args.training.batch_size
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))

    num_training_examples = args.training.num_training_examples


    data_sampler_args = {}
    task_sampler_args = {}
    if num_training_examples is not None:
        # sample_seeds needs at least `bsize` distinct seeds in the pool.
        assert num_training_examples >= bsize
        # NOTE(review): seeds are drawn ONCE here, outside the training loop,
        # so every step reuses the same `bsize` seeds. The effective training
        # set is therefore a single fixed batch, not `num_training_examples`
        # distinct examples. Confirm this is intended.
        seeds = sample_seeds(num_training_examples, bsize)
        data_sampler_args["seeds"] = seeds
        # Task seeds are offset by one from the data seeds.
        task_sampler_args["seeds"] = [s + 1 for s in seeds]
    # Initial task. For 'cot_skill_chain' this object is never resampled in the
    # loop; for 'relu_nn_regression'/'long_chain' it is replaced every step.
    task = task_sampler(**task_sampler_args)
    for i in pbar:

        xs = data_sampler.sample_xs(
            curriculum.n_points,
            bsize,
            curriculum.n_dims_truncated,
            **data_sampler_args,
        )

        # Taken from the task *before* any per-step resampling below; presumably
        # every task instance has the same training metric — verify.
        loss_func = task.get_training_metric()
        if args.training.task in ['relu_nn_regression', 'long_chain']:
            task = task_sampler(**task_sampler_args)
            ys, layer_activations = task.evaluate(xs)
            # Only pass intermediate activations to the model when it is
            # configured to use them.
            if not model.n_intermediate_activations:
                # xs = layer_activations[0].cuda()
                layer_activations = None
            else:
                layer_activations = [act.cuda() for act in layer_activations]
            losses, loss = train_step(model, xs.cuda(), ys.cuda(), optimizer,
                loss_func, layer_activations=layer_activations)
        # TODO
        elif args.training.task in ['cot_skill_chain']:
            # The task replaces xs as well as producing ys; the remaining two
            # outputs are unused here.
            xs, ys, _, _ = task.evaluate(xs)
            if i < args.training.warmup_steps:
                losses, loss = warmup_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func)
            else:
                losses, loss = train_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func)
        else:
            raise NotImplementedError

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    # Per-term losses keyed by their integer index.
                    "stepwise/loss": dict(
                        zip(list(range(len(losses))), losses)
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        # Rolling checkpoint used for resuming (overwritten each time).
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        # Permanent snapshots. The final-step snapshot is written even in a
        # test run, since the `or` clause is not guarded by `args.test_run`.
        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
        ) or (i == args.training.train_steps - 1):
            torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt"))


def main(args):
    """Set up logging, build and train the model, then clean up.

    Test runs (``args.test_run``) pin the curriculum to its final size
    (``start = end`` for both points and dims), cap training at 100 steps and
    never touch wandb. Otherwise a wandb run is started (or resumed) in
    ``args.out_dir``; it is created with ``mode="disabled"`` in debug mode.

    After training, ``args.out_dir`` is deleted when ``args.debug_mode`` is set.
    """
    if not args.test_run:
        wandb_mode = "disabled" if args.debug_mode else "online"
        wandb.init(
            dir=args.out_dir,
            project=args.wandb.project,
            entity=args.wandb.entity,
            config=args.__dict__,
            notes=args.wandb.notes,
            name=args.wandb.name,
            mode=wandb_mode,
            resume=True,
        )
    else:
        # Jump straight to the final curriculum size and keep the run short.
        schedule = args.training.curriculum
        schedule.points.start = schedule.points.end
        schedule.dims.start = schedule.dims.end
        args.training.train_steps = 100

    model = build_model(args.model)
    model.cuda()
    model.train()
    train(model, args)

    # Post-training evaluation is disabled. It previously ran get_run_metrics on
    # out_dir and wrote a results_df.csv summary. Per an earlier TODO,
    # get_evaluation_df is specific to linear regression and currently crashes.
    print("Skipping computation of baselines")

    if not args.debug_mode:
        return
    # Debug runs are throwaway: remove everything written under out_dir.
    print(f"Deleting out_dir {args.out_dir} because of debug mode")
    shutil.rmtree(f"{args.out_dir}", ignore_errors=True)


if __name__ == "__main__":
    parser = QuinineArgumentParser(schema=schema)
    args = parser.parse_quinfig()
    # Explicit check rather than `assert`, so the guard survives `python -O`.
    supported_families = ["gpt2_nn", "gpt2_skill"]
    if args.model.family not in supported_families:
        raise ValueError(
            f"Unsupported model family {args.model.family!r}; "
            f"expected one of {supported_families}"
        )
    print(f"Running with: {args}")

    if args.debug_mode:
        args.out_dir = "../models/debug"

    if not args.test_run:
        # Reuse the configured run id (resume) or start a fresh run directory.
        run_id = args.training.resume_id
        if run_id is None:
            run_id = str(uuid.uuid4())

        out_dir = os.path.join(args.out_dir, run_id)
        # exist_ok avoids the check-then-create race of exists()+makedirs().
        os.makedirs(out_dir, exist_ok=True)
        args.out_dir = out_dir
        # Launch timestamp; stored in the config, so it is dumped to config.yaml below.
        args.wandb['timestamp'] = datetime.datetime.now().strftime("%m/%d/%Y, %H:%M:%S")

        with open(os.path.join(out_dir, "config.yaml"), "w", encoding="utf-8") as yaml_file:
            yaml.dump(args.__dict__, yaml_file, default_flow_style=False)

    main(args)
