import os
from random import randint
import uuid
import numpy as np
from tqdm import tqdm
import torch
import yaml
import itertools
import random
import wandb
import copy
import json
import time
import ast
import sys
from collections import deque
import statistics
import math
from torch.utils.data import DataLoader

from src.eval import get_run_metrics
import src.continuous_task as continuous_task
import src.bool_task as bool_task
from src.samplers import get_continuous_data_sampler
from src.curriculum import Curriculum
from src.args import build_parser
from src.models import build_model
from src.attention_analysis import prefix_scoring_step, nn_scoring_step
from src.utils import model_dist, model_sim
from src.remove_pt import delete_pt_files
from src.split import split_dataset
from src.loader import SyntheticDataset
from src.ncl import get_optimal_ncl_acc, get_optimal_ncl_loss, get_optimal_ncl_bool_loss_from_true_no_context_function, get_optimal_ncl_bool_loss

# Let cuDNN benchmark candidate algorithms and reuse the fastest one per input configuration.
# NOTE(review): input shapes may change as the curriculum grows n_points; confirm this is still a net win.
torch.backends.cudnn.benchmark = True

def update_namespace_with_json(args):
    """Overlay the top-level keys of the JSON file at ``args.json_file_path`` onto ``args``.

    Existing attributes are overwritten and unknown keys become new attributes,
    so values from the JSON file take precedence over the command line.

    Args:
        args: namespace (e.g. from ``argparse``) with a ``json_file_path`` attribute.

    Returns:
        The same ``args`` object, updated in place.
    """
    with open(args.json_file_path, 'r') as json_file:
        overrides = json.load(json_file)
    for name, value in overrides.items():
        setattr(args, name, value)
    return args

def train_step(task_name_list, y_format_list, model, xs_list, ys_list, optimizer, loss_func_list, task_train_list, start_idx=0, precision='full'):
    """Take one optimizer step on the summed loss of several tasks.

    Each task is evaluated as ``model(xs, ys, y_format)`` and scored with its own
    loss function. Tasks whose name contains ``'retrieval'`` are scored on the last
    position only; all other tasks are scored on positions ``start_idx:``. The loss
    on the last position is also computed for every task, for reporting. When the
    task object's ``ncl_opt_loss`` is not None, both losses are divided by it. The
    task losses are summed and back-propagated in a single step.

    Args:
        task_name_list: task name per task (checked for ``'retrieval'``).
        y_format_list: ``y_format`` forwarded to the model, per task.
        model: callable as ``model(xs, ys, y_format)``.
        xs_list: input tensor per task.
        ys_list: target tensor per task.
        optimizer: optimizer over the model's parameters.
        loss_func_list: loss callable ``(prediction, target)`` per task.
        task_train_list: task objects exposing ``ncl_opt_loss``.
        start_idx: first position (dim 1) included in the non-retrieval loss.
        precision: ``'half'`` casts the targets to float16 after the forward pass.

    Returns:
        Tuple ``(loss_list, output_list, last_point_loss_list)``: the per-task
        training loss as a Python number, the detached model output, and the
        last-position loss as a Python number.

    Raises:
        ValueError: if any ``xs`` or ``ys`` tensor contains NaN.
    """
    optimizer.zero_grad()

    graph_losses = []
    loss_values, detached_outputs, last_values = [], [], []
    per_task = zip(xs_list, ys_list, loss_func_list, task_train_list, task_name_list, y_format_list)
    for inputs, targets, criterion, task_obj, name, y_fmt in per_task:
        if torch.isnan(inputs).any() or torch.isnan(targets).any():
            raise ValueError("Input data contains NaN values.")

        prediction = model(inputs, targets, y_fmt)
        if precision == 'half':
            targets = targets.to(torch.float16)

        if 'retrieval' in name:
            task_loss = criterion(prediction[:, -1], targets[:, -1])
        else:
            task_loss = criterion(prediction[:, start_idx:], targets[:, start_idx:])
        final_loss = criterion(prediction[:, -1], targets[:, -1])

        divisor = task_obj.ncl_opt_loss
        if divisor is not None:
            # Normalise by the task's stored NCL optimum so tasks are comparable.
            task_loss, final_loss = task_loss / divisor, final_loss / divisor

        graph_losses.append(task_loss)
        loss_values.append(task_loss.detach().item())
        detached_outputs.append(prediction.detach())
        last_values.append(final_loss.detach().item())

    sum(graph_losses, 0.).backward()
    optimizer.step()
    return loss_values, detached_outputs, last_values

def sample_seeds(total_seeds, count):
    """Draw ``count`` distinct integer seeds uniformly from ``range(total_seeds)``.

    Uses ``random.randint`` (the module-level generator), so results follow
    ``random.seed``.

    Args:
        total_seeds: size of the seed pool; seeds are in ``[0, total_seeds - 1]``.
        count: number of distinct seeds to draw; ``count <= 0`` yields an empty set.

    Returns:
        A set of ``count`` distinct ints.

    Raises:
        ValueError: if ``count`` exceeds ``total_seeds``. A rejection loop can never
            collect that many distinct values and would spin forever.
    """
    if count > total_seeds:
        raise ValueError(f"cannot draw {count} distinct seeds from range({total_seeds})")
    seeds = set()
    while len(seeds) < count:
        seeds.add(randint(0, total_seeds - 1))
    return seeds

def track_gradient_norms(model):
    """Return the L2 norm of each parameter's gradient, keyed by parameter name.

    Parameters whose ``.grad`` is ``None`` are omitted, e.g. parameters that were
    not reached by a backward pass.
    """
    return {
        name: param.grad.norm(2).item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }

def _build_optimizer(model, args):
    """Create the optimizer selected by ``args.precision``, ``args.optimizer`` and ``args.weight_decay``.

    Raises:
        ValueError: in half precision when ``args.optimizer`` is neither ``'adam'`` nor ``'sgd'``.
    """
    if args.precision == 'half':
        if args.optimizer == 'adam':
            # eps=1e-4 presumably because Adam's default eps (1e-8) underflows to 0 in float16.
            return torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), eps=1e-4, lr=args.learning_rate)
        if args.optimizer == 'sgd':
            return torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        raise ValueError('Please choose sgd or adam')
    if args.weight_decay:
        return torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
    return torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate)


def _register_experiment(table_result_path, all_task_name):
    """Load the shared result table, stop if this task mixture is already reported, and reset its entry.

    The entry key is ``str(all_task_name)``. Existing keys are matched as sets of task names,
    so the same mixture in a different order also counts as reported.

    Returns:
        The full (updated) result table, which has already been written back to ``table_result_path``.
    """
    with open(table_result_path, 'r') as f:
        result_file = json.load(f)

    for mixed_tasks, per_task in result_file.items():
        if set(all_task_name) == set(ast.literal_eval(mixed_tasks)):
            if any(entry["plateau"] is not None for entry in per_task.values()):
                print('this mixed case experiment is already reported!')
                sys.exit(0)

    result_file[str(all_task_name)] = {task: {"plateau": None, "complete": None} for task in all_task_name}
    with open(table_result_path, 'w') as f:
        json.dump(result_file, f, indent=4)
    return result_file


def _sample_tasks(task_list, n_dims, n_points):
    """Instantiate one fresh task object per entry of ``task_list`` (boolean or continuous sampler)."""
    task_train_list = []
    for task in task_list:
        task_module = continuous_task if 'bool' not in task['data'] else bool_task
        task_sampler = task_module.get_task_sampler(n_dims, n_points=n_points, **task)
        task_train_list.append(task_sampler())
    return task_train_list


def _log_ncl_optima(task_list, task_train_list, data_sampler, curriculum, args):
    """Estimate each task's NCL optimum, store it on the task object, and log it to wandb.

    NCL is presumably "no-context learning", going by the ``src.ncl`` helper names.
    ReLU continuous tasks and all boolean tasks receive their optimum via ``NCL_optimal_loss``.

    NOTE(review): the caller re-samples task objects every step, so a value stored on these
    step-0 objects only persists if ``NCL_optimal_loss`` stores it somewhere shared. Confirm.
    """
    cont_xs = None
    bool_xs = None
    cont_ys_list = []
    cont_ncl_loss_list = []
    bool_ys_list = []

    for task, task_train in zip(task_list, task_train_list):
        exp_name = task['exp_name']
        if 'bool' not in task['data']:
            if cont_xs is None:
                # One set of inputs is shared by every continuous task.
                cont_xs = data_sampler.sample_xs(args.ncl_bsize, 1, curriculum.n_dims_truncated)
            wb = task_train.sample_wb(args.ncl_wsize)
            ys = task_train.evaluate_by_given_wb(cont_xs, wb)
            individual_ncl_loss = get_optimal_ncl_loss([ys.T])
            cont_ncl_loss_list.append(individual_ncl_loss)
            if 'relu' in task['task']:
                task_train.NCL_optimal_loss(individual_ncl_loss)
            cont_ys_list.append(ys.T)
            if args.wandb:
                wandb.log({f"NCL optimum loss of -{exp_name}": individual_ncl_loss}, step=0)
        else:
            if bool_xs is None:
                # One set of inputs is shared by every boolean task.
                bool_xs = task_train.sample_xs(args.ncl_bsize, 1)
            wb = task_train.sample_wb(args.ncl_wsize)
            ys = task_train.evaluate_by_given_wb(bool_xs, wb)
            individual_ncl_acc = get_optimal_ncl_acc([ys.T])
            individual_ncl_loss = get_optimal_ncl_bool_loss(exp_name)
            task_train.NCL_optimal_loss(individual_ncl_loss)
            bool_ys_list.append(ys.T)
            if args.wandb:
                wandb.log({f"NCL optimum accuracy of -{exp_name}": individual_ncl_acc,
                           f"NCL optimum loss of -{exp_name}": individual_ncl_loss}, step=0)

    if args.wandb:
        if bool_ys_list:
            wandb.log({"Mixed NCL optimum accuracy": get_optimal_ncl_acc(bool_ys_list)}, step=0)
        if cont_ys_list:
            wandb.log({"Mixed NCL optimum loss": get_optimal_ncl_loss(cont_ys_list, cont_ncl_loss_list)}, step=0)


def _sample_batch(task_list, task_train_list, data_sampler, curriculum, device):
    """Draw one training batch per task.

    Retrieval tasks produce inputs and targets together via ``sample_xs_and_ys``.
    Other tasks draw inputs (Gaussian sampler, or the task's own sampler for boolean data)
    and evaluate the task on them.

    Returns:
        ``(xs_list, ys_list, loss_func_list, y_format_list, task_name_list)`` with xs/ys on ``device``.
    """
    xs_list, ys_list, loss_func_list, y_format_list, task_name_list = [], [], [], [], []
    for task, task_train in zip(task_list, task_train_list):
        y_format_list.append(task['y_format'])
        task_name_list.append(task['task'])
        if 'retrieval' in task['task']:
            xs, ys = task_train.sample_xs_and_ys(curriculum.n_points, task['batch_size'])
        else:
            if 'bool' not in task['data']:
                xs = data_sampler.sample_xs(curriculum.n_points, task['batch_size'], curriculum.n_dims_truncated)
            else:
                xs = task_train.sample_xs(curriculum.n_points, task['batch_size'])
            ys = task_train.evaluate(xs)
        xs_list.append(xs.to(device))
        ys_list.append(ys.to(device))
        loss_func_list.append(task_train.get_training_metric())
    return xs_list, ys_list, loss_func_list, y_format_list, task_name_list


def _pointwise_accuracies(task_train_list, output_list, ys_list):
    """Return per-task lists of the point-wise metric averaged over all positions and at the last position.

    Tasks whose point-wise metric is ``squared_error`` report 0 for both.
    """
    mean_acc_list = []
    last_acc_list = []
    for task_train, output, ys in zip(task_train_list, output_list, ys_list):
        metric = task_train.get_metric()
        if metric.__name__ == "squared_error":
            mean_acc_list.append(0)
            last_acc_list.append(0)
            continue
        # Average over dim 0 (presumably the batch), leaving one value per position.
        per_position = metric(output, ys).mean(dim=0)
        mean_acc_list.append(per_position.mean().item())
        last_acc_list.append(per_position[-1].item())
    return mean_acc_list, last_acc_list


def _update_milestones(step, task_list, milestones, loss_list, last_point_loss_list, last_acc_list,
                       train_losses, test_losses, test_accs):
    """Push this step's metrics into the per-task sliding windows and record first-reached milestones.

    ``"plateau"`` is reached when the windowed mean training loss is below 0.8. ``"complete"`` is
    reached when the windowed mean last-point loss is below 0.1 (0.2 for ``quadratic_regression``)
    for ``'gaussian'`` data, or when the windowed mean last-position accuracy exceeds 0.95 otherwise.

    Returns:
        ``(num_new, all_done)``: the number of milestones first recorded at ``step``, and whether
        every task meets both criteria now with both milestones recorded at an earlier step.
    """
    num_new = 0
    all_done = True
    for ii, task in enumerate(task_list):
        task_name = task['exp_name']
        entry = milestones[task_name]

        train_losses[task_name].append(loss_list[ii])
        if task['data'] == 'gaussian':
            test_losses[task_name].append(last_point_loss_list[ii])
            threshold = 0.2 if task_name == 'quadratic_regression' else 0.1
            completed = statistics.mean(test_losses[task_name]) < threshold
        else:
            test_accs[task_name].append(last_acc_list[ii])
            completed = statistics.mean(test_accs[task_name]) > 0.95
        plateaued = statistics.mean(train_losses[task_name]) < 0.8

        for key, reached in (("plateau", plateaued), ("complete", completed)):
            if not reached:
                all_done = False
            elif entry[key] is None:
                entry[key] = step
                num_new += 1
                all_done = False
    return num_new, all_done


def train(model, args):
    """Train ``model`` on the task mixture ``args.task_list`` and track convergence milestones.

    Each step does the following:
      * re-samples one task object per entry of ``args.task_list``;
      * draws a batch for each task;
      * takes one optimizer step on the summed loss (see ``train_step``).

    The first step at which each task plateaus or completes is written to ``./result.json``
    under the key ``str([exp_name, ...])``. Milestone tracking runs whether or not wandb is
    enabled; only the logging is gated on ``args.wandb``.

    The process exits via ``sys.exit(0)`` in two cases:
      * the mixture already has a recorded plateau in the table;
      * every task has reached, and still satisfies, both milestones.

    Args:
        model: model (optionally ``torch.nn.DataParallel``) with a ``device`` attribute; ``main`` sets it.
        args: parsed run configuration.

    Raises:
        ValueError: in half precision with an unsupported ``args.optimizer``.
    """
    optimizer = _build_optimizer(model, args)

    curriculum = Curriculum(args)
    task_list = args.task_list
    device = model.device

    if args.pretrained is not None:
        model.load_state_dict(torch.load(args.pretrained))

    # Reference weights for the wandb "init_distance" metric. They are taken after loading
    # pretrained weights but before restoring a checkpoint. Only needed when logging.
    init_model = copy.deepcopy(model) if args.wandb else None

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        # Replay the curriculum up to and including the saved step.
        for _ in range(starting_step + 1):
            curriculum.update()

    n_dims = model.module.n_dims if isinstance(model, torch.nn.DataParallel) else model.n_dims

    start_idx = 0
    data_sampler = get_continuous_data_sampler('gaussian', n_dims=n_dims)

    pbar = tqdm(range(starting_step, args.train_steps))

    # Shared result table, resolved relative to the current working directory.
    table_result_path = './result.json'
    all_task_name = [task['exp_name'] for task in task_list]
    result_file = _register_experiment(table_result_path, all_task_name)
    milestones = result_file[str(all_task_name)]

    # 100-step sliding windows, pre-filled so no milestone can trigger before 100 real values arrive.
    train_loss_list_of_list = {task['exp_name']: deque([float('inf')] * 100, maxlen=100) for task in task_list}
    test_loss_list_of_list = {task['exp_name']: deque([float('inf')] * 100, maxlen=100) for task in task_list if task['data'] == 'gaussian'}
    test_acc_list_of_list = {task['exp_name']: deque([0] * 100, maxlen=100) for task in task_list if task['data'] == 'boolean'}

    for i in pbar:
        task_train_list = _sample_tasks(task_list, n_dims, curriculum.n_points)

        if i == 0 and args.check_ncl:
            _log_ncl_optima(task_list, task_train_list, data_sampler, curriculum, args)

        xs_list, ys_list, loss_func_list, y_format_list, task_name_list = _sample_batch(
            task_list, task_train_list, data_sampler, curriculum, device)

        loss_list, output_list, last_point_loss_list = train_step(
            task_name_list, y_format_list, model, xs_list, ys_list, optimizer,
            loss_func_list, task_train_list, start_idx, precision=args.precision)

        mean_acc_list, last_acc_list = _pointwise_accuracies(task_train_list, output_list, ys_list)

        update, save_condition = _update_milestones(
            i, task_list, milestones, loss_list, last_point_loss_list, last_acc_list,
            train_loss_list_of_list, test_loss_list_of_list, test_acc_list_of_list)

        if args.wandb:
            init_distance = model_dist(curr_model=model, init_model=init_model, weight_only=True)
            overall_mean_acc = sum(mean_acc_list) / len(mean_acc_list)
            for ii, task in enumerate(task_list):
                wandb.log(
                    {
                        "mean_acc": overall_mean_acc,
                        f"mean_acc-{task['exp_name']}": mean_acc_list[ii],
                        f"last_acc-{task['exp_name']}": last_acc_list[ii],
                        f"overall_loss-{task['exp_name']}": loss_list[ii],
                        f"last_point_loss-{task['exp_name']}": last_point_loss_list[ii],
                    },
                    step=i,
                )

        if update > 0:
            with open(table_result_path, 'w') as f:
                json.dump(result_file, f, indent=4)

        if save_condition:
            sys.exit(0)

        if args.wandb:
            wandb.log({"init_distance": init_distance, "total_loss": sum(loss_list)}, step=i)

        if args.save_iteration is not None and i == args.save_iteration:
            torch.save(model.state_dict(), os.path.join(args.out_dir, f"{args.name}_model_{i}.pt"))
        curriculum.update()

        pbar.set_description(f"loss {sum(loss_list)}")

def main(args):
    """Prepare the run configuration, train the model, then evaluate and report.

    Test runs start the curriculum at its end values and cap training at 100 steps.
    Other runs pin the curriculum dims to ``args.n_dims`` and, if ``args.wandb`` is set,
    initialise a wandb run.

    After training, non-test runs are evaluated with ``get_run_metrics``. With wandb
    enabled, the mean validation accuracy and per-model accuracy curves are then logged.
    Model ``.pt`` files are deleted when ``args.delete`` is set.
    """
    if args.test_run:
        args.curriculum_points_start = args.curriculum_points_end
        args.curriculum_dims_start = args.curriculum_dims_end
        args.train_steps = 100
    else:
        args.curriculum_dims_end = args.n_dims
        args.curriculum_dims_start = args.curriculum_dims_end
        if args.wandb:
            wandb.init(
                dir=args.out_dir,
                project=args.project,
                entity=args.entity,
                config=args.__dict__,
                notes=args.notes,
                name=args.name,
                resume=True,
            )

    if isinstance(args.gpu, list):
        # With several GPUs the first one hosts the model (DataParallel scatters from it).
        device = torch.device(f"cuda:{args.gpu[0]}")
        available_gpus = ','.join(map(str, args.gpu))
        torch.cuda.set_device(device)
        print(f"Using GPUs: {available_gpus}")
    else:
        device = torch.device(f"cuda:{args.gpu}")

    model = build_model(args)

    if isinstance(args.gpu, list) and len(args.gpu) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu)

    model.to(device)
    model.device = device  # read by train()
    model.train()

    train(model, args)

    # eval_metrics only exists for non-test runs, so the wandb report has to live inside this branch.
    if not args.test_run:
        eval_metrics = get_run_metrics(args.out_dir, device=device)

        if args.wandb:
            standard_metrics = eval_metrics['standard']
            eval_models = list(standard_metrics.keys())

            own_name = model.module.name if isinstance(model, torch.nn.DataParallel) else model.name
            mean_val_acc = np.mean(standard_metrics[own_name]['mean'])
            wandb.log({"mean_val_acc": mean_val_acc})

            plot_y = [standard_metrics[model_name]['mean'] for model_name in eval_models]
            plot_x = list(range(len(plot_y[0])))
            wandb.log({'eval/mean_acc': wandb.plot.line_series(plot_x, plot_y, keys=eval_models, title='Accuracy of Different Models', xname='Incontext Examples')})

    if args.delete:
        print('Deleting model (pt) files...')
        delete_pt_files(args.out_dir)

if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    if args.json_file_path != "":
        # Values from the JSON file override (and extend) the command-line arguments.
        args = update_namespace_with_json(args)

    if not args.test_run:
        if args.resume_id == "":
            # Fresh run: make the name unique as <name>_<family>[_<model_name>]_<8-char uuid prefix>.
            args.name += '_' + args.family
            if args.family in ['gpt', 'mysan', 'attclf']:
                args.name += '_' + args.model_name
            args.name += '_' + str(uuid.uuid4())[:8]

        # Output layout: <out_dir>/+<task1>+<task2>.../<run name>
        task_suffix = ''.join('+' + task['task'] for task in args.task_list)
        out_dir = os.path.join(args.out_dir, task_suffix, args.name)
        os.makedirs(out_dir, exist_ok=True)
        args.out_dir = out_dir

        with open(os.path.join(out_dir, "config.yaml"), "w") as yaml_file:
            yaml.dump(args.__dict__, yaml_file, default_flow_style=False)

    main(args)
