import argparse
from tqdm import trange, tqdm

from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy
import copy
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import os
from iclr_code.markov_data import *
from models import *
from pathlib import Path
import warnings
import wandb

# Process-wide: drop every warning (any category, any message, any module).
warnings.simplefilter("ignore")

def get_args(argv=None):
    """Parse command-line options for training/evaluating the in-context model.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the data, training, model, logging and
        run options defined below.
    """
    parser = argparse.ArgumentParser(
        description='Training Transformers for Bayesian Inference')

    # data args
    parser.add_argument(
        '--num-example',
        help='Number of examples for in-context learning',
        type=int,
        default=500)
    parser.add_argument(
        '--num-var',
        help='Number of variables in Markov Chain',
        type=int,
        default=10)

    # training args
    parser.add_argument(
        '--steps',
        help='Number of training steps',
        type=int,
        default=4000)
    parser.add_argument('--batch-size', help='Batch size', type=int, default=64)

    # model args
    parser.add_argument(
        '--layers',
        help='Number of transformer layers',
        type=int,
        default=12)
    parser.add_argument(
        '--heads',
        help='Number of transformer attention heads',
        type=int,
        default=4)
    # '--hid-dim' is accepted as an alias so the flag matches the other
    # dash-separated options; the attribute name stays ``hid_dim``.
    parser.add_argument(
        '--hid_dim',
        '--hid-dim',
        help='Size of hidden dimension',
        type=int,
        default=128)
    parser.add_argument(
        '--model',
        help='Model name (used in the checkpoint directory name)',
        type=str,
        default="transformer")

    # log args
    parser.add_argument(
        '--log-every',
        help='log every X steps',
        type=int,
        default=10)

    # run args: run_test reads args.seed (checkpoint path) and args.wandb,
    # so both must exist on the returned namespace.
    parser.add_argument(
        '--seed',
        help='Seed suffix of the checkpoint directory (..._seed<SEED>/)',
        type=int,
        default=0)  # NOTE(review): default 0 is a guess; match the seed used for training
    parser.add_argument(
        '--wandb',
        help='Log evaluation results to Weights & Biases',
        action='store_true')

    return parser.parse_args(argv)
# torch.manual_seed(1111)

def _true_class0_prob(x, probs, var_idx):
    """Select the reference probability of class 0 for each query in the batch.

    Args:
        x: Input batch; only its last row along dim 1 (the query) is read.
        probs: Reference probabilities yielded by the test dataset.
        var_idx: Index of the variable currently masked out.

    Returns:
        Tensor compared element-wise against the model's class-0 probability.
    """
    if var_idx == 0:
        return probs[:, 0]
    # For every other variable the reference row is picked by the argmax over
    # the first two features of the query row.
    # NOTE(review): this conditions on those two features regardless of
    # var_idx -- confirm it matches how the dataset lays out `probs`.
    cond_value = x[:, -1, :2].view(x.size(0), -1, 2).argmax(-1)[:, 0].long()
    rows = torch.arange(probs.size(0), device=probs.device)
    return probs[rows, cond_value][:, 0]


def run_test(args):
    """Evaluate a trained checkpoint on held-out data, one masked variable at a time.

    For each variable index the test set is re-masked and the model is queried
    with each context size in ``n_examples``. For every context size it records:
    the mean absolute difference between the model's class-0 probability and the
    reference probability, the model's accuracy, and the accuracy of the
    dataset-provided ``graph_pred``. Results are written to
    ``<log_path>/var_<idx>_eval.csv`` and, if ``args.wandb`` is set, logged to
    Weights & Biases.

    Args:
        args: Namespace providing num_var, num_example, layers, hid_dim, heads,
            batch_size, model, seed, steps and (optionally) wandb.
    """
    use_wandb = getattr(args, "wandb", False)
    if use_wandb:
        wandb.init(
            # Set the project where this run will be logged
            project="Transformers-Bayesian-Inference-Test-Markov-Chain",
            # Track hyperparameters and run metadata
            config={
                "steps": args.steps,
                "hid_dim": args.hid_dim,
                "heads": args.heads,
                "layers": args.layers,
                "num_example": args.num_example,
                "num variables": args.num_var
            },
        )

    metric = BinaryAccuracy().cuda()

    eval_num = 1500
    n_examples = [2, 3, 4, 5, 10, 20, 30, 40, 50, 75, 100]

    model = EncoderTransformer(
        n_dims=args.num_var * 3,
        n_embd=args.hid_dim,
        n_layer=args.layers,
        n_head=args.heads).cuda()
    log_path = Path(
        f"ckpt/markov_chain_{args.num_var}/"
        f"{args.model}_{args.num_example}_{args.layers}_{args.hid_dim}"
        f"_{args.heads}_{args.batch_size}_seed{args.seed}"
    )
    model.load_state_dict(torch.load(log_path / "ckpt.pt", weights_only=True))
    # Inference only: make any dropout/normalisation layers use eval behaviour.
    model.eval()
    testset = InContextDatasetTest(eval_num, n_examples[-1], mask_out=0, num_var=args.num_var)

    for var_idx in range(args.num_var):

        print("Running Evaluation on Masking Variable:", var_idx)
        testset.update_mask_out(var_idx)
        test_loader = DataLoader(testset, batch_size=4, shuffle=False)

        prob_diff = {n: 0.0 for n in n_examples}
        model_cor = {n: 0.0 for n in n_examples}
        graph_cor = 0.0
        n_seen = 0  # actual number of evaluated samples, used as denominator

        with torch.no_grad():
            for x, y, probs, graph_pred in test_loader:
                x, probs, graph_pred = x.cuda(), probs.cuda(), graph_pred.cuda()
                y = y.long().cuda().squeeze(-1)
                batch = x.size(0)
                n_seen += batch

                # Neither quantity depends on the context size, so compute once.
                graph_cor += (metric(graph_pred, y) * batch).item()
                true_prob = _true_class0_prob(x, probs, var_idx)
                query = x[:, -1, :].unsqueeze(1)

                for n in n_examples:
                    # First n context examples followed by the query row.
                    inp_x = torch.cat([x[:, :n, :], query], dim=1)
                    pred_prob = torch.softmax(model(inp_x), dim=-1)
                    model_cor[n] += (metric(pred_prob.argmax(-1), y) * batch).item()
                    prob_diff[n] += torch.abs(pred_prob[:, 0] - true_prob).sum().item()

        denom = max(n_seen, 1)  # avoid ZeroDivisionError on an empty test set
        graph_acc = graph_cor / denom
        eval_data = {
            "# example": [],
            "prob difference": [],
            "transformer acc": [],
            "graph acc": []
        }
        for n in n_examples:
            diff = prob_diff[n] / denom
            acc = model_cor[n] / denom
            eval_data["# example"].append(n)
            eval_data["prob difference"].append(diff)
            eval_data["transformer acc"].append(acc)
            eval_data["graph acc"].append(graph_acc)
            # Log every context size, not just the last one.
            if use_wandb:
                wandb.log({
                    f"Var {var_idx} # example": n,
                    f"Var {var_idx} prob difference": diff,
                    f"Var {var_idx} transformer acc": acc,
                    f"Var {var_idx} graph acc": graph_acc
                })

        pd.DataFrame(eval_data).to_csv(log_path / f"var_{var_idx}_eval.csv", index=False)

    if use_wandb:
        wandb.finish()

# Example entry point:
# args = get_args()
# run_test(args)
