手写Spring核心原理MVC实现

阅读 74

2022-05-30

. 背景

本人从事后端刚满一年,主要从事java和python工作,准备在最近一年内多学习各框架的源码,实现技术突破,为了加深印象,本文将通过代码+图文+文字说明的形式来写一个自定义的spring核心控制器dispatcherServlet以及相关组件,以供后续学习。

2. 项目准备

1) 工具

idea2018, jdk1.8,tomcat 8.5

2) spring 核心原理

spring 通过容器并使用工厂模式来创建实例Bean,即控制反转,然后通过将被依赖的类装配到指定的类里实现调用,即我们常用的依赖注入, 通过扫描包的方式来管理所有需要的注解和类,Spring在容器启动的时候初始化好所有的Bean以及相关信息并存储起来,等请求的时候再获取出来,通过反射动态加载的形式调用Controller里对应的方法。

3. 知识储备

①java注解相关知识 ② java反射相关知识 ③servlet相关知识 ④spring相关知识

4. 搭建项目

准备好相关工具后,我就可以开始搭建项目,coding....,完整步骤如下:

1) 建立应用

2) 导入依赖

我在这里没有使用maven来管理依赖,如果没有用maven的话,那么就按照如下方式添加tomcat依赖:

3) 项目目录结构

项目的目录结构截图如下, src作为source_root,即根目录, 如下图

如果src不是根目录,那么可以设置:

4) 配置tomcat

指定名称

指定访问的context

指定发布的包:配置好上述的全部信息就可以正常启动服务器了!启动成功后,会自动调用 http://localhost:8080/spring/ 。

5) 建立项目遇到的问题以及解决方法

①配置好服务器后启动tomcat报错:not found for the web module

解决方法:添加web包到Facets里,在Aitifacts里添加发布到的war包 第一步,选择projet structures。

第二步,选择facets里的web,添加本项目,选择ok即可!

第三步,添加Artifacts, 选择本项目ok即可。

第四步,选择刚配置好的artifacts添加到 edit configuration的deployment里:

重新启动服务器即可解决上述问题!

②启动报错: Cannot start compilation: the output path is not specified for module "spring-study".Specify the output path in the Project Structure dialog.

解决方法: 在当前项目目录下新建一个out目录,指定out的路径即可,如下图:

配置好后,重新启动即可解决问题!

5. 自定义配置

1) 配置servlet

如果上述的东西都准备好了,那么在web.xml文件中定义<servlet-class>标签时不会出现红色,类所在的路径一定要正确, 同时定义<init-param>标签,指定配置文件所在的路径。

<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_4_0.xsd"
         version="4.0">


    <display-name>spring web application</display-name>

    <servlet>
        <servlet-name>DispatcherServlet</servlet-name>
        <servlet-class>com.springframework.demo.servlet.MyDispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>classpath*:application.properties</param-value>
        </init-param>
    </servlet>
    <servlet-mapping>
        <servlet-name>DispatcherServlet</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>
复制代码

2) 自定义配置文件application.properties

application.properties配置文件内容为:

scan-package=com.springframework.demo
复制代码

6. 自定义Spring相关注解

注解名称功能描述作用范围
MyAutoWire装配bean属性FIELD上
MyController请求控制器类TYPE上
MyRequestMapping路由请求类TYPE或者方法METHOD上
MyRequestParam请求参数参数PARAMETER上
MyService标记处理业务逻辑的类类TYPE上

MyAutoWire

package com.springframework.demo.mydefine;


import java.lang.annotation.*;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:01
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyAutoWire {
    String value() default "";
}

复制代码

MyController

package com.springframework.demo.mydefine;


import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:02
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyController {
    String value() default "";
}

复制代码

MyRequestMapping

package com.springframework.demo.mydefine;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:08
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyRequestMapping {
    String value() default "";
}

复制代码

MyRequestParam

package com.springframework.demo.mydefine;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:09
 */
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
public @interface MyRequestParam {
    String value() default "";
}

复制代码

MyService

package com.springframework.demo.mydefine;

import java.lang.annotation.*;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:03
 */
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MyService {

    String value() default "";
}

复制代码

ActionController

package com.springframework.demo;

import com.springframework.demo.mydefine.MyAutoWire;
import com.springframework.demo.mydefine.MyController;
import com.springframework.demo.mydefine.MyRequestMapping;
import com.springframework.demo.mydefine.MyRequestParam;
import com.springframework.demo.service.impl.DemoServiceImpl;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:10
 */

@MyController
@MyRequestMapping("/action")
public class ActionController {

    @MyAutoWire
    private DemoServiceImpl demoService;


    @MyRequestMapping("/query")
    public void querySomeThing(
            HttpServletRequest request,
            HttpServletResponse response,
            @MyRequestParam("id") Integer id,
            @MyRequestParam("username") String username) throws IOException {
        String result = demoService.read(id, username);
        response.getWriter().println(result);
    }
}

复制代码

业务逻辑接口

package com.springframework.demo.service;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:19
 */
public interface IService {

    String read(Integer id, String name);
}

复制代码

业务逻辑类

package com.springframework.demo.service.impl;

import com.springframework.demo.mydefine.MyService;
import com.springframework.demo.service.IService;

/**
 * @author bingbing
 * @date 2020/12/23 0023 15:18
 */
@MyService
public class DemoServiceImpl implements IService {
    @Override
    public String read(Integer id, String name) {
        System.out.println("id=" + id + ",username=" + name);
        String str = "coding is a good habit!";
        System.out.println(str);
        return str;
    }
}

复制代码

7. 核心代码详解

tomcat容器启动时,会调用我们重写的init(ServletConfig config)方法,该方法为GenericServlet抽象类里定义的方法,我们可以在此方法里实现初始化spring容器。

public void init(ServletConfig config) throws ServletException {
        this.config = config;
        this.init();
    }

    public void init() throws ServletException {
    }

复制代码

1) 扫描包下所有class

通过扫描com.springframework.demo包,我们可以得到所有类和接口对应的class文件,将这所有的Class文件对应的全限定名(报名.类名)暂存到一个Map里面

//解析.class文件, 通过注解来识别
    private void doScanner(String scanPackage) {
        System.out.println("开始扫描包!" + mapping);
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File classDir = new File(url.getFile());
        for (File file : classDir.listFiles()) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (!file.getName().endsWith(".class")) {
                continue;
            }
            String clazzName = (scanPackage + "." + file.getName().replace(".class", ""));
            if (!file.isDirectory()) {
                mapping.put(clazzName, null);
            }
        }
        System.out.println("扫描完毕!" + mapping);
    }

复制代码

2) 遍历所有class, 按照类型来分类处理

 mapping.put(url, method);
 mapping.put(beanName, instance);

3) 对Controller下装配的Bean进行强制赋予访问权限

如被autowire注解标记的Bean。

Class clazz = obejct.getClass();
                if (clazz.isAnnotationPresent(MyController.class)) {
                    Field[] fields = clazz.getDeclaredFields();
                    for (Field field : fields) {
                        if (!field.isAnnotationPresent(MyAutoWire.class)) {
                            continue;
                        }
                        MyAutoWire autoWire = field.getAnnotation(MyAutoWire.class);
                        String beanName = autoWire.value();
                        if ("".equals(beanName)) {
                            beanName = field.getType().getName();
                        }
                        //授予权限
                        field.setAccessible(true);
                        field.set(mapping.get(clazz.getName()), mapping.get(beanName));
                    }
                }
复制代码

field.set()方法设置的2个object:

4) 分析Spring容器里的mapping

通过上述代码我们可以发现,spring容器通过hashmap存储controller、service、url请求等组件信息,初始化完毕后 ,得到的mapping里包含了10个键值对,分别对应的功能如下:

键|路径值|对象描述
com.springframework.demo.service.IServicecom.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b装配接口Iservice以及实现类信息
com.springframework.demo.service.impl.DemoServiceImplcom.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b装配实现类对象信息
com.springframework.demo.mydefine.MyControllernull装配controller注解组件
/action/querycom.springframework.demo.ActionController.querySomeThing(javax.servlet.http.HttpServletRequest,javax.servlet.http.HttpServletResponse,java.lang.Integer,java.lang.String) throws java.io.IOException装配url, 以及url请求所对应的method信息
com.springframework.demo.ActionControllercom.springframework.demo.ActionController@39613cd0}处理请求的类信息
map:{
com.springframework.demo.service.IService=com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b, com.springframework.demo.service.impl.DemoServiceImpl=com.springframework.demo.service.impl.DemoServiceImpl@1b4ecd8b, 
com.springframework.demo.mydefine.MyController=null, com.springframework.demo.mydefine.MyService=null, 
/action/query=public void com.springframework.demo.ActionController.querySomeThing(javax.servlet.http.HttpServletRequest,javax.servlet.http.HttpServletResponse,java.lang.Integer,java.lang.String) throws java.io.IOException, 
com.springframework.demo.mydefine.MyAutoWire=null, com.springframework.demo.servlet.MyDispatcherServlet=null, com.springframework.demo.mydefine.MyRequestParam=null, com.springframework.demo.mydefine.MyRequestMapping=null, com.springframework.demo.ActionController=com.springframework.demo.ActionController@39613cd0}

复制代码

5)处理请求

访问 http://localhost:8080/spring/action/query?id=1&username=bingbing 将get请求转到doPost()上,从mapping里获取到url对应的method, 通过method反射执行ActionController里的 querySomeThing( HttpServletRequest request, HttpServletResponse response, @MyRequestParam("id") Integer id, @MyRequestParam("username") String username) 方法, 执行完毕后,使用response给出浏览器响应即可!

@Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("请求成功!");
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws RuntimeException {
        //处理并且分发请求
        try {
            doDispatcher(req, resp);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            System.out.println("InvocationTargetException");
        } catch (IllegalAccessException e) {
            System.out.println("非法访问!");
        }
    }

    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        System.out.println("url:" + url);
        if (!this.mapping.containsKey(url)) {
            resp.getWriter().println("404 NOT FOUND");
        }
        //调用url里对应的方法
        Method method = (Method) mapping.get(url);
        Map<String, String[]> parameterMap = req.getParameterMap();
        method.invoke(this.mapping.get(method.getDeclaringClass().getName()), new Object[]{req, resp, new Integer(parameterMap.get("id")[0]), parameterMap.get("username")[0]});
    }
复制代码

被执行的方法:

@MyRequestMapping("/query")
    public void querySomeThing(
            HttpServletRequest request,
            HttpServletResponse response,
            @MyRequestParam("id") Integer id,
            @MyRequestParam("username") String username) throws IOException {
        String result = demoService.read(id, username);
        response.getWriter().println(result);
    }

复制代码

6) 中心控制器MyDispatcherServlet 完整代码

package com.springframework.demo.servlet;

import com.springframework.demo.mydefine.MyAutoWire;
import com.springframework.demo.mydefine.MyController;
import com.springframework.demo.mydefine.MyRequestMapping;
import com.springframework.demo.mydefine.MyService;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;
import java.io.InputStream;

/**
 * 自定义DispatchServlet
 *
 * @author Administrator
 */
public class MyDispatcherServlet extends HttpServlet {


    /**
     * 用来存放bean或请求的信息
     */
    private Map<String, Object> mapping = new HashMap<>();


    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("请求成功!");
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws RuntimeException {
        //处理并且分发请求
        try {
            doDispatcher(req, resp);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            e.printStackTrace();
            System.out.println("InvocationTargetException");
        } catch (IllegalAccessException e) {
            System.out.println("非法访问!");
        }
    }

    //动态加载参数, 可按照参数的顺序来绑定
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        System.out.println("url:" + url);
        if (!this.mapping.containsKey(url)) {
            resp.getWriter().println("404 NOT FOUND");
        }
        //调用url里对应的方法
        Method method = (Method) mapping.get(url);
        Map<String, String[]> parameterMap = req.getParameterMap();
        Object[] params = new Object[parameterMap.size() + 2];
        params[0] = req;
        params[1] = resp;
        //获取方法的参数类型
        Class<?>[] cls = method.getParameterTypes();
        int index = 2;
        //  index=2  cls=2 ,index=3 ,cls=3
        for (Map.Entry<String, String[]> s : parameterMap.entrySet()) {
            if (cls[index].getName() == "java.lang.Integer") {
                params[index] = new Integer(String.valueOf(s.getValue()[0]));
            } else if (cls[index].getName() == "java.lang.String") {
                params[index] = String.valueOf(s.getValue()[0]);
            }
            index++;
        }
        method.invoke(mapping.get(method.getDeclaringClass().getName()), params);
    }

    @Override
    public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
        super.service(req, res);
    }

    //扫描所有的class
    private void doScanner(String scanPackage) {
        URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.", "/"));
        File rootDir = new File(url.getFile());
        for (File file : rootDir.listFiles()) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (!file.getName().endsWith(".class")) {
                continue;
            }

            if (!file.isDirectory()) {
                String className = scanPackage + "." + file.getName().replaceAll(".class", "");
                mapping.put(className, null);
            }
        }
    }

    @Override
    public void init(ServletConfig config) throws ServletException {
        System.out.println("开始初始化容器....");
        InputStream is = null;
        // 1. 通过classLoader方法getResourceAsStream()获取到配置文件对象
        try {
            Properties configText = new Properties();
            String configName = config.getInitParameter("contextConfigLocation");
            configName = configName.substring(configName.indexOf(":") + 1);
            is = this.getClass().getClassLoader().getResourceAsStream(configName);
            configText.load(is);
            String packageName = configText.getProperty("scan-package");
            // 2. 扫描包,装配到mapping里, key 为 class的全限定名,value为null
            doScanner(packageName);
            // 3. 遍历map的key,设置所有的controller和service。
            for (String clazzName : mapping.keySet()) {
                if (!clazzName.contains(".")) {
                    continue;
                }
                Class<?> clazz = Class.forName(clazzName);
                // 判断Class被哪个注解标记
                if (clazz.isAnnotationPresent(MyController.class)) {
                    // 被controller注解标记的类
                    mapping.put(clazzName, clazz.newInstance());
                    // 如果被requestMapping注解标记
                    String baseurl = "";
                    if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                        MyRequestMapping requestMapping = clazz.getAnnotation(MyRequestMapping.class);
                        baseurl = requestMapping.value();
                    }
                    // 获取controller类下的所有methods
                    Method[] methods = clazz.getMethods();
                    for (Method method : methods) {
                        if (!method.isAnnotationPresent(MyRequestMapping.class)) {
                            continue;
                        }
                        // 获取RequestMapping 对象
                        MyRequestMapping requestMapping = method.getAnnotation(MyRequestMapping.class);
                        String url = baseurl + requestMapping.value();
                        // 解释了url不能重复的原因
                        mapping.put(url, method);
                    }

                } else if (clazz.isAnnotationPresent(MyService.class)) {
                    // 被MyService注解标记的类, 装配接口实现类和clazz,对map的key重新赋值
                    MyService myService = clazz.getAnnotation(MyService.class);
                    String beanName = myService.value();
                    if ("".equals(beanName)) {
                        beanName = clazz.getName();
                    }
                    Object obj = clazz.newInstance();
                    mapping.put(beanName, obj);
                    // 重新装配 Service类下的所有接口
                    for (Class<?> cls : clazz.getInterfaces()) {
                        mapping.put(cls.getName(), obj);
                    }
                } else {
                    continue;
                }

            }
            // 对类下的bean 进行授权可访问
            for (Object obj : mapping.values()) {
                if (obj == null) {
                    continue;
                }
                Class<?> clz = obj.getClass();

                Field[] fields = clz.getDeclaredFields();
                for (Field field : fields) {
                    if (!field.isAnnotationPresent(MyAutoWire.class)) {
                        continue;
                    }
                    MyAutoWire myAutoWire = field.getAnnotation(MyAutoWire.class);
                    String beanName = myAutoWire.value();
                    if ("".equals(beanName)) {
                        // bean名为属性的类型的全限定名
                        beanName = field.getType().getName();
                    }
                    field.setAccessible(true);
                    field.set(mapping.get(clz.getName()), mapping.get(beanName));
                }
            }
            System.out.println("扫描完毕!");
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println(e);
        } finally {
            try {
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
}

复制代码

8. 运行效果展示

访问 http://localhost:8080/spring/action/query?id=1&username=bingbing , 需要带上id和username参数,出现如下效果,表示spring的核心原理精简版就实现了。

9. 代码优化

1)method在使用反射执行的时候,不能动态的绑定方法参数以及参数类型。

我们可以从上述doDispatcher()方法内,使用invoke()方法执行的时候 method.invoke(this.mapping.get(method.getDeclaringClass().getName()), new Object[]{req, resp, new Integer(parameterMap.get("id")[0]), parameterMap.get("username")[0]}); 需要指定参数id和username,此种传参方式传参就是比较固定死板, 因此我想了一个办法是从method里获取到所有的参数类型,再遍历parameterMap的时候根据下标所在位置对应的元素所在类型进行判断是否为cls里面对应的类型。 优化代码如下:

//动态加载参数, 可按照参数的顺序来绑定
    private void doDispatcher(HttpServletRequest req, HttpServletResponse resp) throws IOException, InvocationTargetException, IllegalAccessException {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        url = url.replace(contextPath, "").replaceAll("/+", "/");
        System.out.println("url:" + url);
        if (!this.mapping.containsKey(url)) {
            resp.getWriter().println("404 NOT FOUND");
        }
        //调用url里对应的方法
        Method method = (Method) mapping.get(url);
        Map<String, String[]> parameterMap = req.getParameterMap();
        Object[] params = new Object[parameterMap.size() + 2];
        params[0] = req;
        params[1] = resp;
        //获取方法的参数类型
        Class<?>[] cls = method.getParameterTypes();
        int index = 2;
        //  index=2  cls=2 ,index=3 ,cls=3
        for (Map.Entry<String, String[]> s : parameterMap.entrySet()) {
            if (cls[index].getName() == "java.lang.Integer") {
                params[index] = new Integer(String.valueOf(s.getValue()[0]));
            } else if (cls[index].getName() == "java.lang.String") {
                params[index] = String.valueOf(s.getValue()[0]);
            }
            index++;
        }
        method.invoke(this.mapping.get(method.getDeclaringClass().getName()), params);
    }

复制代码

改进点: 与之前方式不同的是在调用方法时不需要知道参数的名称来做指定参数类型的转换,变的更灵活些。

精彩评论(0)

0 0 举报