import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,"

from collections import OrderedDict
from glob import glob
import random
import numpy as np

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml

from albumentations.augmentations import transforms
from albumentations.augmentations import geometric

from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm
from albumentations import RandomRotate90, Resize
import losses

import archs
from dataset import Dataset

from metrics import iou_score, indicators

from utils import AverageMeter, str2bool

from tensorboardX import SummaryWriter

from thop import profile
from thop import clever_format

import shutil
import os
import subprocess
import time
from pdb import set_trace as st


# Names accepted by the --arch and --loss command-line options.
ARCH_NAMES = archs.__all__
# Build a fresh list instead of appending to ``losses.__all__``: appending
# would mutate the other module's export list in place (and would add the
# name again every time this module is re-executed). 'BCEWithLogitsLoss' is
# served from ``torch.nn`` in main(), not from the ``losses`` module.
LOSS_NAMES = [*losses.__all__, 'BCEWithLogitsLoss']


def list_type(s):
    """Convert a comma-separated string of integers into a list of ints.

    Intended as an argparse ``type=`` converter, e.g. ``"128,160,256"``
    becomes ``[128, 160, 256]``. ``int`` raises ``ValueError`` for any piece
    that is not a valid integer literal (including an empty piece).
    """
    return list(map(int, s.split(',')))


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every training option (model, loss,
        dataset, optimizer, scheduler and profiling switches).
    """
    parser = argparse.ArgumentParser()

    # --- experiment bookkeeping ---
    parser.add_argument('--name', default=None,
                        help='experiment name '
                             '(default: <dataset>_<arch>_wDS or <dataset>_<arch>_woDS)')
    parser.add_argument('--epochs', default=400, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')

    parser.add_argument('--dataseed', default=2981, type=int,
                        help='random seed for the train/val split')

    # --- model ---
    parser.add_argument('--arch', '-a', metavar='ARCH', default='UKAN')

    parser.add_argument('--deep_supervision', default=False, type=str2bool)
    parser.add_argument('--input_channels', default=3, type=int,
                        help='input channels')
    parser.add_argument('--num_classes', default=1, type=int,
                        help='number of classes')
    parser.add_argument('--input_w', default=256, type=int,
                        help='image width')
    parser.add_argument('--input_h', default=256, type=int,
                        help='image height')
    parser.add_argument('--input_list', type=list_type, default=[128, 160, 256])

    # --- loss ---
    parser.add_argument('--loss', default='BCEDiceLoss',
                        choices=LOSS_NAMES,
                        help='loss: ' +
                        ' | '.join(LOSS_NAMES) +
                        ' (default: BCEDiceLoss)')

    # --- data / output locations ---
    parser.add_argument('--dataset', default='busi', help='dataset name')
    parser.add_argument('--data_dir', default='./inputs', help='dataset dir')

    parser.add_argument('--output_dir', default='outputs', help='output dir')

    # --- optimizer ---
    parser.add_argument('--optimizer', default='Adam',
                        choices=['Adam', 'SGD'],
                        help='optimizer: ' +
                        ' | '.join(['Adam', 'SGD']) +
                        ' (default: Adam)')

    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum')
    parser.add_argument('--weight_decay', default=1e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', default=False, type=str2bool,
                        help='nesterov')

    parser.add_argument('--kan_lr', default=1e-2, type=float,
                        metavar='LR',
                        help='initial learning rate for the KAN FC parameter group')
    parser.add_argument('--kan_weight_decay', default=1e-4, type=float,
                        help='weight decay for the KAN FC parameter group')

    # --- scheduler ---
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
    parser.add_argument('--min_lr', default=1e-5, type=float,
                        help='minimum learning rate')
    parser.add_argument('--factor', default=0.1, type=float)
    parser.add_argument('--patience', default=2, type=int)
    parser.add_argument('--milestones', default='1,2', type=str)
    parser.add_argument('--gamma', default=2/3, type=float)
    parser.add_argument('--early_stopping', default=-1, type=int,
                        metavar='N', help='early stopping (default: -1)')
    parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
    parser.add_argument('--num_workers', default=4, type=int)

    parser.add_argument('--no_kan', action='store_true')

    # --- profiling switches ---
    parser.add_argument('--monitor_memory', action='store_true', help='Whether to monitor and output GPU memory usage')
    parser.add_argument('--measure_speed', action='store_true', help='Whether to measure training speed (average of steps 10-29)')
    parser.add_argument('--measure_flops', action='store_true', help='Whether to measure model FLOPs and parameter count')

    return parser.parse_args()


def train(config, train_loader, model, criterion, optimizer):
    """Run one training epoch and return the running-average loss and IoU.

    Args:
        config: dict of run options. Reads ``deep_supervision``,
            ``monitor_memory`` and ``measure_speed``.
        train_loader: iterable yielding ``(input, target, _)`` batches; the
            third element is ignored.
        model: network to optimise; switched to training mode here. With
            ``deep_supervision`` it must return a sequence of outputs, the
            last of which is used for the IoU.
        criterion: callable ``criterion(output, target)`` returning the loss.
        optimizer: optimizer stepped once per batch.

    Returns:
        OrderedDict with keys ``'loss'`` and ``'iou'`` holding the running
        averages reported by the ``AverageMeter`` instances.

    Notes:
        With ``measure_speed`` enabled the epoch is cut short: steps 10-29
        are timed, the average is printed, and the loop stops after step 29.
    """
    avg_meters = {'loss': AverageMeter(),
                  'iou': AverageMeter()}

    model.train()

    monitor_memory = config['monitor_memory']
    measure_speed = config['measure_speed']

    if monitor_memory:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    step_times = []

    with tqdm(total=len(train_loader)) as pbar:
        for step, (images, target, _) in enumerate(train_loader, start=1):
            if measure_speed:
                # CUDA kernels run asynchronously; synchronise so the timer
                # brackets the real GPU work of this step only.
                torch.cuda.synchronize()
                start_time = time.perf_counter()

            images = images.cuda()
            target = target.cuda()

            if config['deep_supervision']:
                outputs = model(images)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)

                iou, _, _ = iou_score(outputs[-1], target)
            else:
                output = model(images)

                loss = criterion(output, target)
                iou, _, _ = iou_score(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if measure_speed:
                torch.cuda.synchronize()
                # The first nine steps are treated as warm-up and not recorded.
                if step >= 10:
                    step_times.append(time.perf_counter() - start_time)

            if monitor_memory:
                # Read after backward/step so the reported peak covers the
                # whole training step, not just the forward pass.
                max_mem = torch.cuda.max_memory_allocated() / (1024 ** 3)
                tqdm.write(f"[Train] Peak GPU Memory: {max_mem:.3f} GB")

            avg_meters['loss'].update(loss.item(), images.size(0))
            avg_meters['iou'].update(iou, images.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
                ('lr', optimizer.param_groups[0]['lr']),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)

            if measure_speed and step >= 29:
                avg_step_time = sum(step_times) / len(step_times)
                steps_per_second = (1.0 / avg_step_time
                                    if avg_step_time > 0 else float('inf'))
                print(f"\nAverage training speed (steps 10-29): {avg_step_time:.6f} seconds per step")
                print(f"Training throughput: {steps_per_second:.2f} steps per second\n")
                # Benchmark mode: stop the epoch once the measurement is done.
                break

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])


def validate(config, val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Args:
        config: dict of run options; only ``deep_supervision`` is read.
        val_loader: iterable yielding ``(input, target, _)`` batches; the
            third element is ignored.
        model: network to evaluate; switched to eval mode here.
        criterion: callable ``criterion(output, target)`` returning the loss.

    Returns:
        OrderedDict with keys ``'loss'``, ``'iou'`` and ``'dice'`` holding
        the running averages reported by the ``AverageMeter`` instances.
    """
    meters = OrderedDict(
        (metric, AverageMeter()) for metric in ('loss', 'iou', 'dice'))

    def current_averages():
        # Snapshot of the running averages, in loss / iou / dice order.
        return OrderedDict((metric, meter.avg) for metric, meter in meters.items())

    model.eval()

    with torch.no_grad(), tqdm(total=len(val_loader)) as progress:
        for images, masks, _ in val_loader:
            images = images.cuda()
            masks = masks.cuda()

            if config['deep_supervision']:
                # Average the loss over every head; score only the last one.
                heads = model(images)
                loss = sum(criterion(head, masks) for head in heads) / len(heads)
                prediction = heads[-1]
            else:
                prediction = model(images)
                loss = criterion(prediction, masks)
            iou, dice, _ = iou_score(prediction, masks)

            batch_size = images.size(0)
            meters['loss'].update(loss.item(), batch_size)
            meters['iou'].update(iou, batch_size)
            meters['dice'].update(dice, batch_size)

            progress.set_postfix(current_averages())
            progress.update(1)

    return current_averages()

def seed_torch(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports ``PYTHONHASHSEED`` with the same value and turns cuDNN
    benchmarking off.
    """
    os.environ.update(PYTHONHASHSEED=str(seed))

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def _describe_model(model):
    """Print the module tree of ``model`` followed by simple layer counts."""
    print("\n" + "="*80)
    print("Model Architecture:")
    print("="*80)
    print(model)
    print("="*80)

    modules = list(model.modules())
    total_modules = len(modules)

    def has_parameters(module):
        return any(True for _ in module.parameters())

    def get_max_depth(module, depth=0):
        # Depth of the deepest leaf below ``module`` (a leaf is at ``depth``).
        return max((get_max_depth(child, depth + 1) for child in module.children()),
                   default=depth)

    def count(types):
        return sum(1 for m in modules if isinstance(m, types))

    trainable_layers = sum(
        1 for m in modules
        if has_parameters(m) and not isinstance(m, nn.ModuleList))
    max_depth = get_max_depth(model)
    conv_layers = count((nn.Conv1d, nn.Conv2d, nn.Conv3d))
    linear_layers = count(nn.Linear)
    norm_layers = count((nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.GroupNorm))
    activation_layers = count((nn.ReLU, nn.SiLU, nn.GELU, nn.LeakyReLU, nn.Tanh, nn.Sigmoid))

    print(f"\nModel Statistics:")
    print(f"  Total modules: {total_modules}")
    print(f"  Trainable layers (with parameters): {trainable_layers}")
    print(f"  Maximum depth: {max_depth}")
    print(f"  Convolutional layers: {conv_layers}")
    print(f"  Linear/FC layers: {linear_layers}")
    print(f"  Normalization layers: {norm_layers}")
    print(f"  Activation layers: {activation_layers}")
    print("="*80 + "\n")


def _report_flops(config, model, writer):
    """Profile ``model`` with thop on one dummy input and print/log the result."""
    dummy_input = torch.randn(
        1, config['input_channels'], config['input_h'], config['input_w']).cuda()
    flops, params = profile(model, inputs=(dummy_input,))
    flops_formatted = f"{flops / 1e6:.3f} M"
    params_formatted = clever_format(params, "%.3f")
    print(f"Model FLOPs: {flops_formatted}")
    print(f"Model Parameters: {params_formatted}")
    writer.add_text('Model Stats', f'FLOPs: {flops}, Params: {params}', global_step=0)


def _build_optimizer(config, model):
    """Create the optimizer named by ``config['optimizer']`` over ``model``."""
    # NOTE(review): nothing is ever routed into ``kan_fc_params``, so
    # --kan_lr / --kan_weight_decay currently have no effect. The split is
    # kept so a name-based filter can be added without touching the rest.
    kan_fc_params = []
    other_params = [param for _, param in model.named_parameters()]

    param_groups = []
    if kan_fc_params:
        param_groups.append(
            {'params': kan_fc_params, 'lr': config['kan_lr'], 'weight_decay': config['kan_weight_decay']})
    if other_params:
        param_groups.append({'params': other_params, 'lr': config['lr'], 'weight_decay': config['weight_decay']})

    kan_numel = sum(p.numel() for p in kan_fc_params)
    other_numel = sum(p.numel() for p in other_params)
    print(f"Total parameter groups: {len(param_groups)}")
    print(f"Total parameters: {kan_numel + other_numel}")
    print(f"KAN FC params count: {len(kan_fc_params)}")
    print(f"Total KAN FC parameters: {kan_numel}")
    print(f"Other params count: {len(other_params)}")
    print(f"Total Other params parameters: {other_numel}")

    if config['optimizer'] == 'Adam':
        return optim.Adam(param_groups)
    if config['optimizer'] == 'SGD':
        # Per-group 'lr' / 'weight_decay' take precedence over these defaults.
        return optim.SGD(param_groups, lr=config['lr'], momentum=config['momentum'],
                         nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    raise NotImplementedError


def _build_scheduler(config, optimizer):
    """Create the LR scheduler named by ``config['scheduler']`` (None for ConstantLR)."""
    if config['scheduler'] == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    if config['scheduler'] == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=1, min_lr=config['min_lr'])
    if config['scheduler'] == 'MultiStepLR':
        return lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
    if config['scheduler'] == 'ConstantLR':
        return None
    raise NotImplementedError


def _snapshot_sources(exp_dir):
    """Copy the source files of this run into ``exp_dir`` for reproducibility.

    Paths are resolved against the current working directory. A missing file
    is reported and skipped so that an incomplete checkout cannot abort a
    training run.
    """
    source_files = ('train.py', 'archs.py', 'archs_MLkan.py', 'kan_MLkan.py',
                    'KAN_MLkan1.py', 'ouv_conv.py', 'kan.py', 'KANConv.py')
    for filename in source_files:
        if os.path.isfile(filename):
            shutil.copy2(filename, exp_dir)
        else:
            print(f"[warning] source snapshot skipped, file not found: {filename}")


def _build_loaders(config):
    """Split the dataset 80/20 and return ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``config['dataset']`` has no known mask extension.
        FileNotFoundError: if no images are found under the dataset directory.
    """
    dataset_name = config['dataset']
    img_ext = '.png'
    mask_exts = {'busi': '_mask.png', 'glas': '.png', 'cvc': '.png'}
    if dataset_name not in mask_exts:
        raise ValueError(
            f"unsupported dataset '{dataset_name}'; expected one of {sorted(mask_exts)}")
    mask_ext = mask_exts[dataset_name]

    img_dir = os.path.join(config['data_dir'], dataset_name, 'images')
    mask_dir = os.path.join(config['data_dir'], dataset_name, 'masks')

    img_ids = sorted(glob(os.path.join(img_dir, '*' + img_ext)))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    if not img_ids:
        raise FileNotFoundError(f"no '*{img_ext}' images found in {img_dir}")
    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed'])

    train_transform = Compose([
        RandomRotate90(),
        geometric.transforms.Flip(),
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=img_dir,
        mask_dir=mask_dir,
        img_ext=img_ext,
        mask_ext=mask_ext,
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=img_dir,
        mask_dir=mask_dir,
        img_ext=img_ext,
        mask_ext=mask_ext,
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)
    return train_loader, val_loader


def main():
    """Entry point: parse options, build everything, then train and validate.

    All artefacts (TensorBoard events, ``config.yml``, ``log.csv``, source
    snapshot and the best ``model_<dice>.pth`` checkpoint) are written to
    ``<output_dir>/<name>``. The checkpoint is replaced whenever the
    validation IoU improves.
    """
    seed_torch()
    config = vars(parse_args())

    # Resolve the default name BEFORE deriving any path from it, so that
    # logs and checkpoints end up in the same, existing directory.
    if config['name'] is None:
        suffix = 'wDS' if config['deep_supervision'] else 'woDS'
        config['name'] = '%s_%s_%s' % (config['dataset'], config['arch'], suffix)

    exp_name = config['name']
    output_dir = config['output_dir']
    exp_dir = f'{output_dir}/{exp_name}'
    os.makedirs(exp_dir, exist_ok=True)

    my_writer = SummaryWriter(exp_dir)
    try:
        print('-' * 20)
        for key in config:
            print('%s: %s' % (key, config[key]))
        print('-' * 20)

        with open(f'{exp_dir}/config.yml', 'w') as f:
            yaml.dump(config, f)

        if config['loss'] == 'BCEWithLogitsLoss':
            criterion = nn.BCEWithLogitsLoss().cuda()
        else:
            criterion = losses.__dict__[config['loss']]().cuda()

        # NOTE(review): this re-enables cuDNN autotuning, overriding the
        # ``benchmark = False`` set by seed_torch(); confirm that is intended.
        cudnn.benchmark = True

        model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'], no_kan=config['no_kan'])
        model = model.cuda()

        _describe_model(model)

        if config['measure_flops']:
            _report_flops(config, model, my_writer)

        optimizer = _build_optimizer(config, model)
        scheduler = _build_scheduler(config, optimizer)

        _snapshot_sources(exp_dir)

        train_loader, val_loader = _build_loaders(config)

        log = OrderedDict([
            ('epoch', []),
            ('lr', []),
            ('loss', []),
            ('iou', []),
            ('val_loss', []),
            ('val_iou', []),
            ('val_dice', []),
        ])

        best_iou = 0
        best_dice = 0
        trigger = 0  # epochs since the last improvement of the validation IoU
        for epoch in range(config['epochs']):
            print('Epoch [%d/%d]' % (epoch, config['epochs']))

            # Learning rate actually used during this epoch (first group).
            epoch_lr = optimizer.param_groups[0]['lr']

            train_log = train(config, train_loader, model, criterion, optimizer)
            val_log = validate(config, val_loader, model, criterion)

            if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_log['loss'])

            print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
                  % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))

            log['epoch'].append(epoch)
            log['lr'].append(epoch_lr)
            log['loss'].append(train_log['loss'])
            log['iou'].append(train_log['iou'])
            log['val_loss'].append(val_log['loss'])
            log['val_iou'].append(val_log['iou'])
            log['val_dice'].append(val_log['dice'])

            pd.DataFrame(log).to_csv(f'{exp_dir}/log.csv', index=False)

            my_writer.add_scalar('train/loss', train_log['loss'], global_step=epoch)
            my_writer.add_scalar('train/iou', train_log['iou'], global_step=epoch)
            my_writer.add_scalar('val/loss', val_log['loss'], global_step=epoch)
            my_writer.add_scalar('val/iou', val_log['iou'], global_step=epoch)
            my_writer.add_scalar('val/dice', val_log['dice'], global_step=epoch)

            trigger += 1

            if val_log['iou'] > best_iou:
                # Save first, then drop older checkpoints, so a failed save
                # never leaves the run without its previous best model.
                model_path = f'{exp_dir}/model_{val_log["dice"]:.4f}.pth'
                torch.save(model.state_dict(), model_path)
                for stale in glob(f'{exp_dir}/model_*.pth'):
                    if os.path.abspath(stale) != os.path.abspath(model_path):
                        os.remove(stale)

                best_iou = val_log['iou']
                best_dice = val_log['dice']
                print("=> saved best model")
                print('IoU: %.4f' % best_iou)
                print('Dice: %.4f' % best_dice)
                trigger = 0

            # Logged after the update so the curve includes this epoch.
            my_writer.add_scalar('val/best_iou_value', best_iou, global_step=epoch)
            my_writer.add_scalar('val/best_dice_value', best_dice, global_step=epoch)

            if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
                print("=> early stopping")
                break

            torch.cuda.empty_cache()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        my_writer.close()
    
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
