import argparse
import os
import random
import time
import sys
import clip

import pandas as pd
import json
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR

from libs import models
from libs.checkpoint import resume, save_checkpoint
from libs.class_id_map import get_n_classes
from libs.class_weight import get_class_weight, get_pos_weight
from libs.config import get_config
from libs.dataset import ActionSegmentationDataset, collate_fn
from libs.helper import train, validate, evaluate_with_return
from libs.loss_fn import ActionSegmentationLoss, BoundaryRegressionLoss, KLLoss
from libs.loss_fn.cross import TSConstraintLoss, RobustTSConstraint, KLTSConstraint

from libs.optimizer import get_optimizer
from libs.transformer import TempDownSamp, ToTensor, RandomJointMask, RandomNDRotation
from libs.swap import SegmentSwapTransform
from prompt.text_prompt import TextCLIP, text_prompt_for_class, text_prompt_for_joint, text_prompt_for_joint_custom, text_prompt_for_class_custom


def get_arguments() -> argparse.Namespace:
    """
    Parse all the arguments from the command line interface.

    Returns:
        argparse.Namespace: the parsed options (``dataset``, ``result_path``,
        ``cuda``, ``resume``, ``resume_ckpt``, ``resume_log``,
        ``refinement_method``, ``alpha``, ``beta``, ``cross``, ``config``).
    """

    parser = argparse.ArgumentParser(
        description="train a network for action segmentation"
    )
    parser.add_argument("--dataset", type=str, default="PKU-view", help="name of the dataset")
    parser.add_argument("--result_path", type=str, default="./result", help="path of a result")
    parser.add_argument("--cuda", type=int, default=0, help="cuda id")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Add --resume option if you start training from checkpoint.",
    )
    parser.add_argument(
        "--resume_ckpt", type=str, default="", help="resume from specified checkpoint",
    )
    parser.add_argument(
        "--resume_log", type=str, default="", help="resume from specified log",
    )
    parser.add_argument(
        "--refinement_method",
        type=str,
        default="refinement_with_boundary",
        choices=["refinement_with_boundary", "relabeling", "smoothing"],
    )
    parser.add_argument("--alpha", type=float, default=0, help="alpha value of curve")
    parser.add_argument("--beta", type=float, default=1.0, help="beta value of curve")
    # The help text used to be a copy of the --beta one; main() feeds this
    # value to TSConstraintLoss as alpha=cross and beta=cross/2.
    parser.add_argument(
        "--cross",
        type=float,
        default=2,
        help="weight of the cross constraint loss (alpha=cross, beta=cross/2)",
    )
    parser.add_argument("--config", type=str, default='config', help="config path")

    return parser.parse_args()

def import_class(import_str):
    """Import and return the object named by a dotted path.

    Args:
        import_str: dotted path of the form ``"package.module.Name"``; the
            part before the last dot is imported as a module and the part
            after it is looked up as an attribute of that module.

    Returns:
        The attribute ``Name`` of the imported module.

    Raises:
        ImportError: if the module was imported but has no such attribute
            (a failing module import raises whatever ``__import__`` raises).
    """
    # Imported here because the file's top-level imports do not provide it;
    # without it the error path below died with NameError instead of ImportError.
    import traceback

    mod_str, _sep, class_str = import_str.rpartition('.')
    __import__(mod_str)
    try:
        return getattr(sys.modules[mod_str], class_str)
    except AttributeError as err:
        raise ImportError(
            'Class %s cannot be found (%s)'
            % (class_str, traceback.format_exception(*sys.exc_info()))
        ) from err

def change_label_score(best_test, train_loss, epoch, cls_acc, edit_score, f1s):
    """Overwrite the score entries of ``best_test`` in place.

    Args:
        best_test: dict that receives the values (mutated, nothing returned).
        train_loss: stored under ``'train_loss'``.
        epoch: stored under ``'epoch'``.
        cls_acc: stored under ``'cls_acc'``.
        edit_score: stored under ``'edit'``.
        f1s: indexable with at least five items; items 0..4 are stored under
            ``'f1s@0.1'``, ``'f1s@0.25'``, ``'f1s@0.5'``, ``'f1s@0.75'`` and
            ``'f1s@0.9'`` respectively.
    """
    scalar_entries = (
        ('train_loss', train_loss),
        ('epoch', epoch),
        ('cls_acc', cls_acc),
        ('edit', edit_score),
    )
    for key, value in scalar_entries:
        best_test[key] = value

    # Index explicitly (rather than zip) so a too-short f1s still raises.
    f1_keys = ('f1s@0.1', 'f1s@0.25', 'f1s@0.5', 'f1s@0.75', 'f1s@0.9')
    for position, key in enumerate(f1_keys):
        best_test[key] = f1s[position]

def _load_log(log_dir):
    """Read ``log.csv`` from ``log_dir``, dropping a saved index column if present."""
    log = pd.read_csv(os.path.join(log_dir, "log.csv"))
    # The per-epoch save writes the DataFrame index, which comes back as an
    # 'Unnamed: 0' column; the final save uses index=False. Accept both.
    return log.drop(columns=["Unnamed: 0"], errors="ignore")


def main() -> None:
    """Train the action segmentation model described by the command line and config.

    Reads the CLI arguments and ``config/<dataset>/<config>.yaml``, builds the
    data loaders, the model, the CLIP text encoder, the optimizer, the
    scheduler and the loss functions, then runs the training loop. When
    ``config.param_search`` is set, every epoch is also validated and the best
    accuracy / F1@0.1 / F1@0.5 epochs are tracked.

    Outputs written under ``<result_path>/<dataset>/split<split>``:
    ``tensorboard_logs/``, ``scores.txt`` and ``log.csv``.
    """

    start_start = time.time()

    # argparser
    args = get_arguments()
    dataset_name = args.dataset
    device_num = args.cuda
    alpha = args.alpha
    beta = args.beta
    cross = args.cross
    # configuration
    config = get_config(f"config/{dataset_name}/{args.config}.yaml")

    result_path = os.path.join(args.result_path, config.dataset, 'split' + str(config.split))

    print('\n---------------------------result_path---------------------------\n')
    print('result_path:', result_path)
    os.makedirs(result_path, exist_ok=True)

    log_dir = os.path.join(result_path, "tensorboard_logs")
    writer = SummaryWriter(log_dir)

    # Note: opened with "w", so an existing scores.txt is truncated on every start.
    with open(f'{result_path}/scores.txt', "w") as file:
        file.write('The result printed:\n')

    seed = config.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Bound up front so the CPU path does not hit a NameError further down.
    output_device = "cpu"
    if device == "cuda":
        # NOTE(review): this re-enables cudnn.benchmark, overriding the
        # benchmark = False set together with the seeds above — confirm
        # which of the two is intended.
        torch.backends.cudnn.benchmark = True
        device = device_num
        output_device = device_num[0] if isinstance(device_num, list) else device_num
        torch.cuda.set_device(output_device)
        # NOTE(review): CUDA_VISIBLE_DEVICES is assigned after set_device();
        # presumably it no longer affects this process at that point — verify.
        if isinstance(device_num, list):
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_num))
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = f'{device_num}'

        current_device = torch.cuda.current_device()
        print(f"Currently using GPU {current_device}")

    downsamp_rate = 4 if config.dataset == "LARA" else 1

    train_data = ActionSegmentationDataset(
        config.dataset,
        transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
        mode="trainval" if not config.param_search else "training",
        split=config.split,
        dataset_dir=config.dataset_dir,
        csv_dir=config.csv_dir,
    )

    train_loader = DataLoader(
        train_data,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        drop_last=config.batch_size > 1,
        collate_fn=collate_fn,
    )

    if config.param_search:
        val_data = ActionSegmentationDataset(
            config.dataset,
            transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
            mode="validation",
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
        )

        val_loader = DataLoader(
            val_data,
            batch_size=1,
            shuffle=False,
            num_workers=config.num_workers,
            collate_fn=collate_fn,
        )

    # The test loader is built here but not consumed anywhere in this function.
    test_data = ActionSegmentationDataset(
        config.dataset,
        transform=Compose([ToTensor(), TempDownSamp(downsamp_rate)]),
        mode="test",
        split=config.split,
        dataset_dir=config.dataset_dir,
        csv_dir=config.csv_dir,
    )

    test_loader = DataLoader(
        test_data,
        batch_size=1,
        shuffle=False,
        num_workers=config.num_workers,
        collate_fn=collate_fn,
    )

    # load model
    print("---------- Loading Model ----------")

    n_classes = get_n_classes(config.dataset, dataset_dir=config.dataset_dir)

    class_text_list = text_prompt_for_class(dataset_name, "detail")
    joint_text_list = text_prompt_for_joint(dataset_name, "detail")

    Model = import_class(config.model)
    model_, preprocess = clip.load("ViT-B/32", "cuda" if torch.cuda.is_available() else "cpu")
    model_text = TextCLIP(model_)

    # .to() instead of .cuda() so the same line also works for the CPU fallback,
    # matching how the model and the weight tensors are moved below.
    model_text = model_text.to(output_device)
    with open(f'{result_path}/scores.txt', "a+") as file:
        file.write(
            f'Model: {config.model}\n'
        )
    model = Model(
        in_channel=config.in_channel,
        n_features=config.n_features,
        n_classes=n_classes,
        n_stages=config.n_stages,
        n_layers=config.n_layers,
        n_refine_layers=config.n_refine_layers,
        n_stages_asb=config.n_stages_asb,
        n_stages_brb=config.n_stages_brb,
        SFI_layer=config.SFI_layer,
        dataset=config.dataset,
    )

    model.to(output_device)

    print("Model loaded successfully")
    optimizer = get_optimizer(
        config.optimizer,
        model,
        config.learning_rate,
        momentum=config.momentum,
        dampening=config.dampening,
        weight_decay=config.weight_decay,
        nesterov=config.nesterov,
    )

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=config.max_epoch,
        eta_min=1e-6,
        last_epoch=-1,
        verbose=True,
    )

    columns = ["epoch", "lr", "train_loss"]

    if config.param_search:
        columns += ["val_loss", "cls_acc", "edit"]
        columns += ["f1s@{}".format(threshold) for threshold in config.iou_thresholds]
        columns += ["bound_acc", "precision", "recall", "bound_f1s"]

    begin_epoch = 0
    best_loss = float("inf")

    best_test_acc = {'epoch': 0, 'train_loss': 0, 'cls_acc': 0, 'edit': 0, 'f1s@0.1': 0,
                     'f1s@0.25': 0, 'f1s@0.5': 0, 'f1s@0.75': 0, 'f1s@0.9': 0}
    best_test_F1_10 = best_test_acc.copy()
    best_test_F1_50 = best_test_acc.copy()

    log = pd.DataFrame(columns=columns)
    if args.resume:
        if os.path.exists(os.path.join(result_path, "checkpoint.pth")):
            checkpoint = resume(result_path, model, optimizer)
            begin_epoch, model, optimizer, best_loss = checkpoint
            # The saved index column must be dropped, otherwise the per-epoch
            # row built below no longer matches log.columns in length.
            log = _load_log(result_path)
            print("training will start from {} epoch".format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")
    elif args.resume_ckpt:
        if os.path.exists(args.resume_ckpt):
            checkpoint = resume(args.resume_ckpt, model, optimizer)
            begin_epoch, model, optimizer, best_loss = checkpoint
            if args.resume_log:
                log = _load_log(args.resume_log)
            print("training will start from {} epoch".format(begin_epoch))
        else:
            print("there is no checkpoint at the result folder")

    if config.class_weight:
        class_weight = get_class_weight(
            config.dataset,
            split=config.split,
            dataset_dir=config.dataset_dir,
            csv_dir=config.csv_dir,
            mode="training" if config.param_search else "trainval",
        )
        class_weight = class_weight.to(output_device)
    else:
        class_weight = None

    # Classification loss assembled from the terms enabled in the config.
    # NOTE(review): gstmse_weight receives config.gstmse (the same value that
    # is passed as the `gstmse` argument), not a dedicated weight setting —
    # confirm this is intended.
    criterion_cls = ActionSegmentationLoss(
        ce=config.ce,
        focal=config.focal,
        tmse=config.tmse,
        gstmse=config.gstmse,
        weight=class_weight,
        ignore_index=255,
        ce_weight=config.ce_weight,
        focal_weight=config.focal_weight,
        tmse_weight=config.tmse_weight,
        gstmse_weight=config.gstmse,
        cl=config.cl,
        cl_weight=config.cl_weight,
        circle=config.circle,
        circle_weight=config.circle_weight,
    ).to(output_device)

    cross_criterion = TSConstraintLoss(alpha=cross, beta=cross / 2, loss_type='mse')

    pos_weight = get_pos_weight(
        dataset=config.dataset,
        split=config.split,
        csv_dir=config.csv_dir,
        mode="training" if config.param_search else "trainval",
    ).to(output_device)

    criterion_bound = BoundaryRegressionLoss(pos_weight=pos_weight, smoothed=True).to(output_device)
    criterion_curve = BoundaryRegressionLoss(pos_weight=pos_weight, smoothed=True).to(output_device)
    criterion_contrast = KLLoss().to(output_device)

    # train and validate model
    print("---------- Start training ----------")
    # Running means over the last 20 epochs; accumulated below but not
    # reported anywhere in this function.
    avg_cls_acc = 0
    avg_edit_score = 0
    avg_segment_f1s = [0, 0, 0, 0, 0]
    avg_bound_acc = 0
    avg_precision = 0
    avg_recall = 0
    avg_bound_f1s = 0

    for epoch in range(begin_epoch, config.max_epoch):
        # training
        start = time.time()

        train_loss = train(
            train_loader,
            model,
            model_text,
            class_text_list,
            joint_text_list,
            criterion_cls,
            criterion_bound,
            criterion_contrast,
            config.lambda_b,
            optimizer,
            dataset_name,
            device, output_device,
            criterion_curve=criterion_curve,
            batch_cnt=epoch,
            alpha=alpha,
            beta=beta,
            cross_criterion=cross_criterion,
            scheduler=scheduler,
        )
        train_time = (time.time() - start) / 60

        # if you do validation to determine hyperparams
        if config.param_search:
            start = time.time()
            (
                val_loss,
                cls_acc,
                edit_score,
                segment_f1s,
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ) = validate(
                val_loader,
                model,
                model_text,
                joint_text_list,
                criterion_cls,
                criterion_bound,
                config.lambda_b,
                device, output_device,
                config.dataset,
                config.dataset_dir,
                config.iou_thresholds,
                config.boundary_th,
                config.tolerance,
                config.refinement_method,
            )
            if epoch >= config.max_epoch - 20:
                avg_cls_acc += cls_acc / 20
                avg_edit_score += edit_score / 20
                avg_segment_f1s = [a + b / 20 for a, b in zip(avg_segment_f1s, segment_f1s)]
                avg_bound_acc += bound_acc / 20
                avg_precision += precision / 20
                avg_recall += recall / 20
                avg_bound_f1s += bound_f1s / 20

            # Epoch 0 is excluded from the best-score bookkeeping.
            if epoch > 0:
                if best_loss > val_loss:
                    best_loss = val_loss

                if cls_acc > best_test_acc['cls_acc']:
                    change_label_score(best_test_acc, train_loss, epoch, cls_acc, edit_score, segment_f1s)

                if segment_f1s[0] > best_test_F1_10['f1s@0.1']:
                    change_label_score(best_test_F1_10, train_loss, epoch, cls_acc, edit_score, segment_f1s)

                if segment_f1s[2] > best_test_F1_50['f1s@0.5']:
                    change_label_score(best_test_F1_50, train_loss, epoch, cls_acc, edit_score, segment_f1s)

        # write logs to dataframe and csv file
        tmp = [epoch, optimizer.param_groups[0]["lr"], train_loss]

        # if you do validation to determine hyperparams
        if config.param_search:
            tmp += [
                val_loss,
                cls_acc,
                edit_score,
            ]
            tmp += segment_f1s
            tmp += [
                bound_acc,
                precision,
                recall,
                bound_f1s,
            ]

        tmp_df = pd.DataFrame(tmp, index=log.columns).T
        log = pd.concat([log, tmp_df], ignore_index=True)
        log.to_csv(os.path.join(result_path, "log.csv"))

        # Measured from the last assignment of `start`, i.e. it also covers the
        # CSV logging above (and the training itself when no validation ran).
        val_time = (time.time() - start) / 60

        eta_time = (config.max_epoch - epoch) * (train_time + val_time)
        writer.add_scalar('Loss/train', train_loss, epoch)
        if config.param_search:
            writer.add_scalar('Loss/validation', val_loss, epoch)
            writer.add_scalar('Metrics/cls_acc', cls_acc, epoch)
            writer.add_scalar('Metrics/edit_score', edit_score, epoch)
            writer.add_scalar('Metrics/f1s@0.1', segment_f1s[0], epoch)
            writer.add_scalar('Metrics/f1s@0.25', segment_f1s[1], epoch)
            writer.add_scalar('Metrics/f1s@0.5', segment_f1s[2], epoch)
            writer.add_scalar('Metrics/bound_acc', bound_acc, epoch)
            writer.add_scalar('Metrics/bound_f1s', bound_f1s, epoch)

        if config.param_search:
            # Build the summary once; it is printed and appended to scores.txt.
            summary = (
                'epoch: {}, lr: {:.4f}, train_time: {:.2f}min, train loss: {:.4f}, val_time: {:.2f}min, eta_time: {:.2f}min, \nval_loss: {:.4f}, acc: {:.2f}, edit: {:.2f}, F1@0.1: {:.2f}, F1@0.25: {:.2f}, F1@0.5: {:.2f}, bound_acc: {:.2f}, bound_f1: {:.2f}'
                .format(epoch, optimizer.param_groups[0]['lr'], train_time, train_loss, val_time, eta_time, val_loss, cls_acc,
                        edit_score, segment_f1s[0], segment_f1s[1], segment_f1s[2], bound_acc, bound_f1s)
            )
            print(summary)

            with open(f'{result_path}/scores.txt', "a+") as file:
                file.write(summary + '\n')
        else:
            print(
                "epoch: {}\tlr: {:.4f}\ttrain loss: {:.4f}".format(
                    epoch, optimizer.param_groups[0]["lr"], train_loss
                )
            )

    print('\n---------------------------best_test_acc---------------------------\n')
    print('{}'.format(best_test_acc))
    print('\n---------------------------best_test_F1_10---------------------------\n')
    print('{}'.format(best_test_F1_10))
    print('\n---------------------------best_test_F1_50---------------------------\n')
    print('{}'.format(best_test_F1_50))
    print('\n---------------------------all_train_time---------------------------\n')
    print('all_train_time: {:.2f}min'.format((time.time() - start_start) / 60))

    with open(f'{result_path}/scores.txt', "a+") as file:
        file.write('\n---------------------------best_test_acc---------------------------\n')
        file.write('{}'.format(best_test_acc))
        file.write('\n---------------------------best_test_F1_10---------------------------\n')
        file.write('{}'.format(best_test_F1_10))
        file.write('\n---------------------------best_test_F1_50---------------------------\n')
        file.write('{}'.format(best_test_F1_50))
        file.write('\n---------------------------all_train_time---------------------------\n')
        file.write('all_train_time: {:.2f}min'.format((time.time() - start_start) / 60))

    # Append the three best-score summaries as extra rows of the final log.
    best_test_acc = pd.DataFrame.from_dict(best_test_acc, orient='index').T
    best_test_F1_10 = pd.DataFrame.from_dict(best_test_F1_10, orient='index').T
    best_test_F1_50 = pd.DataFrame.from_dict(best_test_F1_50, orient='index').T
    log = pd.concat([log, best_test_acc], ignore_index=True)
    log = pd.concat([log, best_test_F1_10], ignore_index=True)
    log = pd.concat([log, best_test_F1_50], ignore_index=True)
    log.to_csv(os.path.join(result_path, 'log.csv'), index=False)

    # Flush pending TensorBoard events and release the event file.
    writer.close()

    print("Done!")


if __name__ == "__main__":
    main()
