import math
import torch
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from src.models import *


def build_model(args, logger, device):
    """Instantiate the model selected by ``args`` and move it to ``device``.

    A recognised ``args.model_name`` always selects the matching model class,
    even when ``args.supervised`` is set. Only when the name is not recognised
    does a true ``args.supervised`` fall back to ``build_encoder(args)``; that
    encoder is given an empty ``loss_history`` list and an empty
    ``evaluation_history`` dict.

    Args:
        args: Run configuration. This function reads ``supervised``,
            ``model_name`` and ``architecture`` (the constructors may read more).
        logger: Receives one ``info`` message naming the created model.
        device: Passed to ``model.to``.

    Returns:
        The constructed model after ``model.to(device)``.

    Raises:
        ValueError: If ``args.model_name`` is not recognised and
            ``args.supervised`` is false.
    """
    # NOTE(review): the label reads "Supervised" whenever args.supervised is
    # set, even if a named model from the table below ends up being built.
    display_name = "Supervised" if args.supervised else args.model_name

    # Constructors are wrapped in lambdas so a class is only looked up and
    # called when its name is actually selected.
    factories = {
        "SDMI": lambda: SDMI(args),
        "BYOL": lambda: BYOL(args),
        "SimSiam": lambda: SimSiam(args),
        "SimCLR": lambda: SimCLR(args),
        "MoCo": lambda: MoCo(args),
        "JMI": lambda: JMI(args),
        "SimSiam_SDMI": lambda: SimSiam_SDMI(args),
        "BarlowTwins": lambda: BarlowTwins(args),
        "VICReg": lambda: VICReg(args),
    }

    factory = factories.get(args.model_name)
    if factory is not None:
        model = factory()
    elif args.supervised:
        model = build_encoder(args)
        # Containers the training/plotting code fills in for supervised runs.
        model.loss_history = []
        model.evaluation_history = {}
    else:
        raise ValueError(f"Model {args.model_name} not recognized.")

    model.to(device)
    logger.info(f"{display_name} model created using {args.architecture}.")
    return model
    

def get_optimizer(args, model, learning_rate, weight_decay, momentum, evaluation=False):
    """Create the optimizer(s) for ``model`` as selected by ``args.optimizer``.

    For SDMI training (``args.model_name == "SDMI"`` and ``evaluation`` false)
    two independent optimizers are returned as ``(optimizer_E, optimizer_M)``:
    one over ``E_encoder`` + ``E_projector`` parameters and one over
    ``M_encoder`` + ``M_projector`` parameters. In every other case a single
    optimizer over ``model.parameters()`` is returned.

    Args:
        args: Run configuration. Reads ``model_name`` and ``optimizer``
            (``"SGD"`` or ``"Adam"``).
        model: Module whose parameters are optimized.
        learning_rate: Initial learning rate for every optimizer created.
        weight_decay: Weight decay for every optimizer created.
        momentum: SGD momentum. It is ignored for Adam.
        evaluation: When true, SDMI models also get a single optimizer over
            all parameters.

    Returns:
        A ``torch.optim.Optimizer``, or a 2-tuple of them for SDMI training.

    Raises:
        ValueError: If ``args.optimizer`` is not ``"SGD"`` or ``"Adam"``.
    """

    def _make(params):
        # Single construction point so SGD and Adam see the same hyperparameters.
        if args.optimizer == "SGD":
            return torch.optim.SGD(
                params=params,
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay,
            )
        if args.optimizer == "Adam":
            # Use the caller-supplied rate (as SGD does) so evaluation-time
            # rates are honoured instead of always using args.initial_lr.
            return torch.optim.Adam(
                params=params,
                lr=learning_rate,
                weight_decay=weight_decay,
            )
        raise ValueError(f"Optimizer {args.optimizer} not recognized.")

    if args.model_name == "SDMI" and not evaluation:
        E_params = [*model.E_encoder.parameters(), *model.E_projector.parameters()]
        M_params = [*model.M_encoder.parameters(), *model.M_projector.parameters()]
        optimizer_E = _make(E_params)
        optimizer_M = _make(M_params)
        return (optimizer_E, optimizer_M)

    return _make(model.parameters())


def adjust_lr(args, optimizer, current_epoch):
    """Set the learning rate for ``current_epoch``: linear warmup, then cosine decay.

    While ``current_epoch < args.warmup_epochs`` the rate rises linearly from
    ``args.warmup_initial_lr`` toward ``args.initial_lr``. After that it
    follows a half-cosine from ``args.initial_lr`` down to 0, reaching 0 at
    ``current_epoch == args.epochs`` and staying there for later epochs.

    Args:
        args: Run configuration. Reads ``warmup_epochs``,
            ``warmup_initial_lr``, ``initial_lr`` and ``epochs``.
        optimizer: A single optimizer, or a list/tuple of optimizers (e.g. the
            SDMI pair). Every param group of every optimizer gets the same rate.
        current_epoch: Epoch index the rate is computed for.

    Returns:
        The learning rate that was written into the param groups.
    """
    warmup = args.warmup_epochs

    if current_epoch < warmup:
        lr = args.warmup_initial_lr + (args.initial_lr - args.warmup_initial_lr) * (current_epoch / warmup)
    else:
        # max(..., 1) avoids ZeroDivisionError when epochs == warmup_epochs.
        decay_epochs = max(args.epochs - warmup, 1)
        # Clamp so epochs past the schedule stay at the floor rather than
        # the cosine climbing back up.
        progress = min((current_epoch - warmup) / decay_epochs, 1.0)
        lr = args.initial_lr * 0.5 * (1 + math.cos(math.pi * progress))

    optimizers = optimizer if isinstance(optimizer, (list, tuple)) else (optimizer,)
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = lr

    return lr


def plotter(feature_data, label=None, y_label=None, x_label='Epoch', color=None, integer_x=False,
            x_data=None, marker=None, figsize=(10, 6), existing_plot=False, final_plot=True,
            storage_destination=None):
    """Draw one series with pyplot and, if ``final_plot``, save and close the figure.

    Several series can share one figure. Call first with
    ``existing_plot=False, final_plot=False``, then with ``existing_plot=True``
    for further series, and pass ``final_plot=True`` on the last call.

    Args:
        feature_data: y values to plot.
        label: Legend entry for this series. A legend is drawn only when truthy.
        y_label: y-axis label, also used in the title and file name
            (``"Value"`` is substituted there when it is None).
        x_label: x-axis label, also used in the title and file name
            (``"Index"`` is substituted there when it is None).
        color: Forwarded to ``plt.plot``.
        integer_x: Restrict major x-axis ticks to integers.
        x_data: x values. Defaults to ``range(len(feature_data))``.
        marker: Forwarded to ``plt.plot``.
        figsize: Size of the new figure when ``existing_plot`` is false.
        existing_plot: Draw onto the current figure instead of creating one.
        final_plot: Save the figure as ``<y>_over_<x>.png`` (lower-cased,
            spaces replaced by ``_``) and close it. When false the figure is
            left open for more series.
        storage_destination: Directory for the PNG, created if missing. When
            falsy the file name is used as a relative path.
    """
    x_values = x_data if x_data is not None else range(len(feature_data))

    if not existing_plot:
        plt.figure(figsize=figsize)

    # Title/file-name text must be a string; a None label would otherwise
    # crash on .lower() below.
    y_text = y_label if y_label is not None else "Value"
    x_text = x_label if x_label is not None else "Index"

    plt.plot(x_values, feature_data, color=color, label=label, marker=marker)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(f'{y_text} over {x_text}')
    plt.grid(True)

    if label:
        plt.legend()

    if integer_x:
        plt.gca().xaxis.set_major_locator(mticker.MaxNLocator(integer=True))

    if not final_plot:
        # More series will be drawn onto this figure; keep it open.
        return

    file_name = f'{y_text.lower().replace(" ", "_")}_over_{x_text.lower().replace(" ", "_")}.png'

    if storage_destination:
        os.makedirs(storage_destination, exist_ok=True)
        save_path = os.path.join(storage_destination, file_name)
    else:
        save_path = file_name

    plt.savefig(save_path)
    plt.close()


def _history_to_series(history):
    """Split a metric history into ``(x_values, y_values)`` for plotting."""
    if isinstance(history, dict):
        # Keys presumably look like "<prefix>-<epoch>" (e.g. "epoch-5"); the
        # second hyphen-separated field becomes the x value. Verify with the
        # code that fills these histories.
        x_values = [int(key.split('-')[1]) for key in history]
        return x_values, list(history.values())
    return list(range(len(history))), history


def graph_plotting_manager(args, run_paths, model):
    """Plot the metric histories stored on ``model`` into ``run_paths.graphs_directory``.

    For ``args.model_name == "SDMI"`` four graphs are drawn:
    - E-encoder and M-encoder linear-probing accuracy.
    - E-step and M-step total loss.

    Otherwise two graphs are drawn:
    - An accuracy graph: ``model.evaluation_history`` when ``args.supervised``
      is set, else ``model.linear_probing_history``.
    - A total-loss graph from ``model.loss_history``.

    Dict histories are plotted against the epoch parsed from their keys.
    Sequence histories are plotted against their index. Empty histories
    produce an empty graph.

    Args:
        args: Run configuration. Reads ``model_name`` and ``supervised``.
        run_paths: Object whose ``graphs_directory`` receives the PNG files.
        model: Model carrying the history attributes listed above.
    """
    if args.model_name == "SDMI":
        graphs_to_plot = {
            'E-encoder Linear Probing Accuracy': {
                'feature': model.E_linear_probing_history,
                'plot_color': 'teal',
            },
            'M-encoder Linear Probing Accuracy': {
                'feature': model.M_linear_probing_history,
                'plot_color': 'orange',
            },
            'E-Step Total Loss': {
                'feature': model.E_loss_history,
                'plot_color': 'blue'
            },
            'M-Step Total Loss': {
                'feature': model.M_loss_history,
                'plot_color': 'red'
            }
        }

    else:
        if args.supervised:
            accuracy_title = 'Supervised Accuracy'
            accuracy_feature = model.evaluation_history
        else:
            accuracy_title = 'Linear Probing Accuracy'
            accuracy_feature = model.linear_probing_history

        graphs_to_plot = {
            accuracy_title: {
                'feature': accuracy_feature,
                'plot_color': 'teal',
            },
            'Total Loss': {
                'feature': model.loss_history,
                'plot_color': 'blue'
            }
        }

    for title, config in graphs_to_plot.items():
        # Always hand plotter plain sequences: an empty dict history would
        # otherwise reach plt.plot as `{}` and fail with a shape mismatch.
        x_data, y_data = _history_to_series(config['feature'])

        plotter(
            y_data,
            y_label=title,
            x_label=config.get('x_label', "Epoch"),
            color=config.get('plot_color', 'blue'),
            x_data=x_data,
            storage_destination=run_paths.graphs_directory
        )
