import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
#from timm.utils import accuracy

from torch.nn.utils import clip_grad_norm_
from torch import optim
import pickle
from config import DATASET_CONFIG
from util.misc import get_next_run_number

import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from util.pos_embed import interpolate_pos_embed

#from model import PhysioModel
from model_normwear import NormWear
from dataset import LinearProbDataset,linprob_collate_fn
import torch.nn as nn
from tqdm import tqdm
from util.misc import freeze_model
from util.head import RegressionHead, ClassificationHead
from torch.utils.tensorboard import SummaryWriter

import util.misc as misc
from engine_linprob import train_one_epoch,evaluate
from util.misc import NativeScalerWithGradNormCount as NativeScaler

def parse_tuple(arg_str):
    """Convert a CLI string like ``"(43,13)"`` or ``"43,13"`` to a float tuple.

    One pair of surrounding parentheses is optional and removed. The remainder
    is split on commas, and every piece is converted with ``float``. A malformed
    piece makes ``float`` raise ``ValueError``.
    """
    text = arg_str
    has_parens = text.startswith('(') and text.endswith(')')
    if has_parens:
        text = text[1:-1]
    return tuple(float(piece) for piece in text.split(','))

def get_args_parser():
    """Build the argparse parser for the linear-probing script.

    The parser is created with ``add_help=False``. Boolean options use the
    module's ``str2bool`` converter. Plain ``type=bool`` would treat any
    non-empty string, including ``"False"``, as True.
    """
    parser = argparse.ArgumentParser('PhysioModel linear probing', add_help=False)
    parser.add_argument('--batch_size', default=32, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=40, type=int)
    parser.add_argument('--warmup_epochs', type=int, default=4, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--y_range', default=None, type=parse_tuple)
    parser.add_argument('--accum_iter', default=1, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
    parser.add_argument('--remark', default='model_mae_t6f5',
                        help='model_remark')
    # Model parameters
    parser.add_argument('--model_weight_dir', default='../data/results', type=str,
                        help='path of model weight')
    parser.add_argument('--model_name', default='model_mae', type=str,
                        help='Name of model to train')
    # str2bool instead of bool: bool("False") is True.
    # nargs='?' + const=True also allows a bare `--eval` flag.
    parser.add_argument('--eval', default=False, type=str2bool, nargs='?', const=True,
                        help='Perform evaluation only')
    parser.add_argument('--checkpoint', default='../data/results/model_mae_checkpoint-140.pth',
                        type=str, help='model checkpoint for evaluation')
    parser.add_argument('--task', default='reg',
                        type=str, help='reg/cls')
    parser.add_argument('--log_dir', default='../data/results/log',
                        help='path where to tensorboard log')
    parser.add_argument('--new_size', default=(43, 13), type=parse_tuple,
                        help='new number of patches')
    parser.add_argument('--use_meanpooling', default=0, type=int,
                        help='meanpooling cls fusion')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay (default: 0 for linear probe following MoCo v1)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')

    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--seed', default=42, type=int)

    # Dataset parameters
    # Same type=bool pitfall as --eval; parsed with str2bool.
    parser.add_argument('--is_pretrain', default=False, type=str2bool, nargs='?', const=True)
    parser.add_argument('--data_path', default='data/pretrain', type=str,
                        help='dataset path')
    parser.add_argument('--ds_name', default='maus', type=str,
                        help='dataset name')

    parser.add_argument('--num_classes', type=int, default=1,
                        help='number of the classification types')
    parser.add_argument('--num_channel', default=1, type=int,
                        help='number of the input channels')

    parser.add_argument('--output_dir', default='.././data/results',
                        help='path where to save, empty for no saving')

    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    return parser

def str2bool(v):
    """Interpret a command-line token as a boolean.

    A ``bool`` is returned as-is. Strings are compared case-insensitively:
    yes/true/t/y/1 give True and no/false/f/n/0 give False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def load_dataset(args,ds_name="wesad"):
    """Return ``(train_fnames, test_fnames)`` for dataset ``ds_name``.

    ``args.is_pretrain`` selects between two modes.

    Pretrain mode:
        Every entry of ``<args.data_path>/<ds_name>`` is split 80/20 with
        ``train_test_split`` (``random_state=42``). The joined paths are
        returned.

    Downstream mode:
        A pickled dict with ``"train_fnames"`` and ``"test_fnames"`` is read
        from ``data/downstream/<ds_name>/splits`` relative to the working
        directory (the "pod" layout). If that file does not exist, it is read
        from ``../data/downstream/<ds_name>/splits`` (the "PVC" layout).

        Backslashes in the stored names become ``/``. In the PVC layout, a
        leading literal ``"data"`` is replaced by ``"../../data/downstream"``.

    Raises:
        FileNotFoundError: if the selected split file does not exist.
    """
    if args.is_pretrain:
        fnames = os.listdir(os.path.join(args.data_path, ds_name))
        train_fnames, test_fnames = train_test_split(fnames, test_size=0.2, random_state=42)

        train_fnames = [os.path.join(args.data_path, ds_name, fname) for fname in train_fnames]
        test_fnames = [os.path.join(args.data_path, ds_name, fname) for fname in test_fnames]
        # Previously this branch fell through without a return, so callers
        # unpacking the result got None.
        return train_fnames, test_fnames

    # Define pod and PVC paths; use the pod path if it exists, otherwise the PVC path.
    pod_path = f"data/downstream/{ds_name}/splits"
    pvc_path = f"../{pod_path}"
    use_pvc = not os.path.exists(pod_path)
    split_path = pvc_path if use_pvc else pod_path

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at split files you produced yourself.
    with open(split_path, 'rb') as f:
        split = pickle.load(f)

    data_prefix = "data"

    def _normalize(fname):
        # Normalize Windows separators, then re-root for the PVC layout.
        fname = fname.replace("\\", "/")
        if use_pvc:
            # Remove the literal "data" prefix. The old `.lstrip("data")`
            # stripped any run of leading 'd'/'a'/'t' characters instead.
            if fname.startswith(data_prefix):
                fname = fname[len(data_prefix):]
            fname = "../../data/downstream" + fname
        return fname

    train_fnames = [_normalize(fname) for fname in split["train_fnames"]]
    test_fnames = [_normalize(fname) for fname in split["test_fnames"]]
    return train_fnames, test_fnames


class LinearProb(nn.Module):
    """Trainable linear probe on top of a frozen backbone.

    The backbone is passed to ``freeze_model``, and its ``feature_extractor``
    is run under ``torch.no_grad()``. Its output is mean-pooled over dim 2.
    The per-variable vectors are then combined with the learnable weights
    ``fuse_w``. These are initialised to ``1/num_channel`` each, so the
    initial fusion is a plain average. A single ``nn.Linear`` maps the fused
    embedding to ``num_classes`` outputs.

    Args:
        backbone: module exposing ``feature_extractor(x)``.
        num_classes: output width of the linear head.
        num_channel: number of variables fused by ``fuse_w``.
        embed_size: feature width fed to the head (default 768).
        task: accepted for API compatibility; not used by this class.
        y_range: accepted for API compatibility; not used by this class.
    """

    def __init__(self, backbone,
                 num_classes, num_channel,
                 embed_size=768, task='reg',
                 y_range=None):
        super().__init__()
        self.backbone = backbone
        freeze_model(self.backbone)

        # Learnable per-variable fusion weights, shaped to broadcast over (bs, E, nvar).
        initial_weights = torch.ones(1, 1, num_channel) / num_channel
        self.fuse_w = nn.Parameter(initial_weights)
        self.head = nn.Linear(embed_size, num_classes)

        self.embed_size = embed_size

    def forward(self, x):
        """Run the frozen backbone, fuse variables, and apply the linear head.

        The input layout is whatever ``backbone.feature_extractor`` expects.
        The original comment describes it as (bs x nvar, ch, L, F); this has
        not been verified here.

        Returns:
            A tuple ``(out, weighted)``:
              * ``out`` is the head output of shape ``(bs, num_classes)``,
                for both regression and classification.
              * ``weighted`` holds the pooled features after permuting to
                ``(bs, E, nvar)`` and multiplying by ``fuse_w``.
        """
        with torch.no_grad():
            # Presumably (bs, nvar, L+1, E) per the original comment.
            tokens = self.backbone.feature_extractor(x)

        pooled = tokens.mean(dim=2)
        weighted = pooled.permute(0, 2, 1) * self.fuse_w
        fused = weighted.sum(dim=2)
        return self.head(fused), weighted
    

def _build_samplers(train_dataset, test_dataset):
    """Return ``(sampler_train, sampler_val)``, both ``DistributedSampler``.

    DistributedSampler is used even for a single process. In that case
    ``misc.get_world_size()`` presumably returns 1.
    """
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))
    if len(test_dataset) % num_tasks != 0:
        print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
              'This will slightly alter validation results as extra duplicate entries are added to achieve '
              'equal num of samples per-process.')
    sampler_val = torch.utils.data.DistributedSampler(
        test_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)  # shuffle=True to reduce monitor bias
    return sampler_train, sampler_val


def _load_backbone(args, checkpoint_model):
    """Build a fresh NormWear backbone and load ``checkpoint_model`` into it (non-strict)."""
    backbone = NormWear(img_size=(387, 65), patch_size=(9, 5), nvar=args.num_channel,
                        is_pretrain=False, use_meanpooling=args.use_meanpooling,
                        mask_prob=0,)
    # Interpolating position embedding
    #interpolate_pos_embed(backbone, checkpoint_model,orig_size=(43,13),new_size=args.new_size)
    msg = backbone.load_state_dict(checkpoint_model, strict=False)
    print(msg)
    return backbone


def main(args):
    """Run the linear-probing experiment three times and report mean/std accuracy.

    Each run does the following:
      1. Builds a fresh backbone from the weights in ``args.checkpoint``.
      2. Wraps it in a new ``LinearProb``.
      3. Trains the probe for ``args.epochs`` epochs.
      4. Records the best ``acc1`` returned by ``evaluate``.

    Afterwards the mean and standard deviation of the per-run maxima are
    printed. If ``args.output_dir`` is set, the last run's model is saved with
    ``misc.save_model``.
    """
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # Always bound. The loop checks `log_writer is not None`, which
    # previously raised NameError when --log_dir was None.
    log_writer = None
    if args.log_dir is not None:
        path = os.path.join(args.log_dir, args.ds_name)
        os.makedirs(path, exist_ok=True)
        run_number = get_next_run_number(path)
        log_run_dir = os.path.join(path, f'run_{args.remark}_{run_number}')
        os.makedirs(log_run_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=log_run_dir)

    train_fnames, test_fnames = load_dataset(args, ds_name=args.ds_name)
    train_dataset = LinearProbDataset(train_fnames, task=args.task)
    test_dataset = LinearProbDataset(test_fnames, task=args.task)

    # `device` is a torch.device, not a str: compare its type.
    num_workers = 0 if device.type == 'cpu' else 4
    print('number of workers', num_workers)

    print("Number of Training Samples:", len(train_dataset))
    print("Number of Testing Samples:", len(test_dataset))

    sampler_train, sampler_val = _build_samplers(train_dataset, test_dataset)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=num_workers,
                              sampler=sampler_train,
                              drop_last=True, pin_memory=True,
                              collate_fn=linprob_collate_fn,)

    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=num_workers,
                             drop_last=False, pin_memory=True,
                             sampler=sampler_val,
                             collate_fn=linprob_collate_fn)

    # If --lr is unset, derive it from --blr as the parser help documents.
    # Otherwise the "%.3e" print below would crash on None.
    if args.lr is None:
        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
        args.lr = args.blr * eff_batch_size / 256

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    # Load the checkpoint once. load_state_dict copies tensors into each fresh
    # backbone, so the same dict can be reused for every run.
    print('Loading pre-trained checkpoint from', args.checkpoint)
    checkpoint_model = torch.load(args.checkpoint, map_location='cpu')['model']

    # Run the experiment multiple times to get a standard deviation.
    mult_exp_acc = []
    epoch = None  # stays None if args.epochs == 0
    for run in range(3):
        print('Exp Number: ', run)
        max_accuracy = 0.0
        # Fresh backbone + probe each run so runs are independent.
        backbone = _load_backbone(args, checkpoint_model)

        model = LinearProb(backbone,
                           num_classes=args.num_classes,
                           num_channel=args.num_channel,
                           task=args.task,
                           y_range=args.y_range)

        print("Model = %s" % str(model))
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of training params : %.2f' % (n_parameters))
        model.to(device)

        print("lr: %.3e" % args.lr)

        # Bound in both modes. save_model needs the unwrapped module.
        model_without_ddp = model
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module

        # Only trainable (probe) parameters go to the optimizer; the backbone is frozen.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
        print(optimizer)
        loss_scaler = NativeScaler()

        criterion = torch.nn.L1Loss() if args.task == 'reg' else torch.nn.CrossEntropyLoss()
        print("criterion = %s" % str(criterion))

        for epoch in tqdm(range(args.epochs)):
            # DistributedSampler only reshuffles when its epoch changes.
            sampler_train.set_epoch(epoch)
            train_one_epoch(
                model, criterion, train_loader,
                optimizer, device, epoch, loss_scaler,
                max_norm=None,
                log_writer=log_writer,
                args=args,)

            test_stats = evaluate(args, test_loader, model, device, criterion)
            # len(test_dataset) is the sample count; len(test_loader) is the batch count.
            print(f"Accuracy of the network on the {len(test_dataset)} test samples: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

            # NOTE(review): all runs log to the same tags/steps, so they overlap in TensorBoard.
            if log_writer is not None:
                log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
                log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)

        mult_exp_acc.append(max_accuracy)

    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    print(f'{args.ds_name} has mean accuracy of {np.mean(mult_exp_acc)} with standard deviation of {np.std(mult_exp_acc)}')
    print('Done !!')

    # Save the last run's model. Skip when no epoch ran: `epoch` would be undefined/meaningless.
    if args.output_dir and epoch is not None:
        misc.save_model(
            args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
            loss_scaler=loss_scaler, epoch=epoch)


if __name__ == '__main__':
    args = get_args_parser().parse_args()

    # Per-dataset settings from config.DATASET_CONFIG take precedence over
    # the corresponding CLI values.
    ds_cfg = DATASET_CONFIG[args.ds_name]
    args.num_channel = ds_cfg["n_ch"]
    args.num_classes = ds_cfg["n_cl"]
    args.task = ds_cfg["task"]
    args.lr = ds_cfg["lr"]
    args.batch_size = ds_cfg["bs"]
    args.remark = args.remark + args.ds_name
    # y_range is optional in the config; if absent it is reset to None.
    args.y_range = ds_cfg["y_range"] if "y_range" in ds_cfg else None

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)